Tensor slice sampler #329

Merged · 14 commits · May 5, 2023
Changes from 9 commits
4 changes: 3 additions & 1 deletion .github/workflows/test.yml
@@ -33,14 +33,16 @@ jobs:
        run: |
          python -m pip install --upgrade pip
          pip install coveralls

      - name: Install dev packages
        run: |
          pip install ".[dev]"

      - name: Install TF package
        run: |
          pip install tensorflow==${{ matrix.tf-version }}
          # Fix proto dep issue in protobuf 4
          pip install protobuf==3.20.*

      - name: Lint with flake8
        run: |
164 changes: 164 additions & 0 deletions tensorflow_similarity/samplers/tfdata_sampler.py
@@ -0,0 +1,164 @@
from __future__ import annotations

from collections.abc import Callable, Sequence

import tensorflow as tf


def create_grouped_dataset(
    ds: tf.data.Dataset,
    class_list: Sequence[int] | None = None,
    total_examples: int | None = None,
    buffer_size: int | None = None,
) -> list[tf.data.Dataset]:
    """
    Creates a list of datasets grouped by class id.

    Args:
        ds: A `tf.data.Dataset` object.
        class_list: An optional `Sequence` of integers representing the classes
            to include in the dataset. If `None`, all classes are included.
        total_examples: An optional integer representing the maximum number of
            examples to include per class. If `None`, all examples are included.
        buffer_size: An optional integer representing the size of the buffer
            used for shuffling each class dataset. If `None`, no shuffling is
            applied.

    Returns:
        A list of repeated `tf.data.Dataset` objects, one per class id.
    """
    window_size = ds.cardinality().numpy()
    grouped_by_cid = ds.group_by_window(
        key_func=lambda x, y, *args: y,
        reduce_func=lambda key, ds: ds.batch(window_size),
        window_size=window_size,
    )

    cid_datasets = []
    for elem in grouped_by_cid.as_numpy_iterator():
        # This assumes that elem is at least (x, y) and that y is a tensor or array_like.
        if class_list is not None and elem[1][0] not in class_list:
            continue
        if total_examples is not None:
            # Slice each component so we keep the first `total_examples`
            # examples per class, rather than truncating the (x, y) tuple itself.
            elem = tuple(e[:total_examples] for e in elem)
        cid_ds = tf.data.Dataset.from_tensor_slices(elem)
        if buffer_size is not None:
            cid_ds = cid_ds.shuffle(buffer_size)
        cid_datasets.append(cid_ds.repeat())

    return cid_datasets
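
# Usage sketch (hypothetical values; assumes a finite, labeled (x, y) dataset):
#
#     ds = tf.data.Dataset.from_tensor_slices(
#         (tf.range(8, dtype=tf.float32), tf.constant([0, 0, 1, 1, 2, 2, 3, 3], dtype=tf.int64))
#     )
#     groups = create_grouped_dataset(ds, class_list=[0, 1], buffer_size=2)
#     # -> [ds_class_0, ds_class_1], each an infinite, per-class-shuffled dataset.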


def create_choices_dataset(num_classes: int, examples_per_class: int) -> tf.data.Dataset:
    """
    Creates a dataset that generates random integers in `[0, num_classes)`.
    Integers are generated in contiguous blocks of size `examples_per_class`.
    Integers are sampled without replacement and are not selected again until
    all other integer values have been sampled.

    Args:
        num_classes: An integer representing the total number of classes in the dataset.
        examples_per_class: An integer representing the number of examples per class.

    Returns:
        A `tf.data.Dataset` object representing the dataset with random choices.
    """
    return (
        tf.data.Dataset.range(num_classes)
        .shuffle(num_classes)
        .map(lambda z: [[z] * examples_per_class], name="repeat_cid")
        .flat_map(tf.data.Dataset.from_tensor_slices)
        .repeat()
    )
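
# Output sketch (hypothetical values): with num_classes=3 and
# examples_per_class=2, one pass over the shuffled range might yield
#
#     2, 2, 0, 0, 1, 1, ...
#
# i.e. each class id repeated in a contiguous block, reshuffled every epoch.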


def apply_augmenter_ds(ds: tf.data.Dataset, augmenter: Callable, warmup: int | None = None) -> tf.data.Dataset:
    """
    Applies an augmenter function to a dataset batch and optionally delays
    applying the function for `warmup` number of batches.

    Args:
        ds: A `tf.data.Dataset` object.
        augmenter: A callable used to apply data augmentation to individual
            examples within each batch.
        warmup: An optional integer representing the number of batches to wait
            before applying the data augmentation function. If `None`, the
            augmenter is applied from the first batch.

    Returns:
        A `tf.data.Dataset` object with the augmenter applied.
    """
    if warmup is None:
        return ds.map(augmenter, name="augmenter")

    aug_ds = ds.map(augmenter, name="augmenter").skip(warmup)
    # Use a tuple comparison so versions such as 2.11 or 3.0 also select the
    # newer counter API targeted at TF >= 2.10.
    major, minor = (int(v) for v in tf.__version__.split(".")[:2])
    if (major, minor) >= (2, 10):
        count_ds = tf.data.Dataset.counter()
    else:
        count_ds = tf.data.experimental.Counter()

    ds = tf.data.Dataset.choose_from_datasets(
        [ds, aug_ds],
        count_ds.map(lambda x: tf.cast(0, dtype=tf.dtypes.int64) if x < warmup else tf.cast(1, dtype=tf.dtypes.int64)),
    )

    return ds
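
# Warmup sketch (hypothetical values): the first `warmup` batches pass through
# unaugmented, after which every batch comes from the augmented stream:
#
#     ds = tf.data.Dataset.range(10).batch(2)
#     ds = apply_augmenter_ds(ds, lambda x: x + 10, warmup=1)
#     # batch 0 -> [0, 1] (unaugmented), batch 1 -> [12, 13], batch 2 -> [14, 15], ...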


def TFDataSampler(
    ds: tf.data.Dataset,
    classes_per_batch: int = 2,
    examples_per_class_per_batch: int = 2,
    class_list: Sequence[int] | None = None,
    total_examples_per_class: int | None = None,
    augmenter: Callable | None = None,
    load_fn: Callable | None = None,
    warmup: int | None = None,
) -> tf.data.Dataset:
    """
    Returns a `tf.data.Dataset` object that generates batches with an equal
    number of examples per class. The input dataset cardinality must be finite
    and known.

    Args:
        ds: A `tf.data.Dataset` object representing the original dataset.
        classes_per_batch: An integer specifying the number of classes per batch.
        examples_per_class_per_batch: An integer specifying the number of
            examples per class per batch.
        class_list: An optional sequence of integers representing the class IDs
            to include in the dataset. If `None`, all classes in the original
            dataset will be used.
        total_examples_per_class: An optional integer representing the total
            number of examples per class to use. If `None`, all examples for
            each class will be used.
        augmenter: An optional function to apply data augmentation to each
            example in a batch.
        load_fn: An optional callable that loads examples from disk.
        warmup: An optional integer specifying the number of batches to use for
            unaugmented warmup. If `None`, no warmup will be used.

    Returns:
        A `tf.data.Dataset` object representing the class-balanced dataset.
    """
    if ds.cardinality() == tf.data.INFINITE_CARDINALITY:
        raise ValueError("Dataset must be finite.")
    if ds.cardinality() == tf.data.UNKNOWN_CARDINALITY:
        raise ValueError("Dataset cardinality must be known.")

    grouped_dataset = create_grouped_dataset(ds, class_list, total_examples_per_class)
    choices_ds = create_choices_dataset(len(grouped_dataset), examples_per_class_per_batch)

    batch_size = examples_per_class_per_batch * classes_per_batch

    ds = tf.data.Dataset.choose_from_datasets(grouped_dataset, choices_ds).repeat().batch(batch_size)

    if load_fn is not None:
        ds = ds.map(load_fn, name="load_example_fn")

    if augmenter is not None:
        ds = apply_augmenter_ds(ds, augmenter, warmup)

    ds = ds.prefetch(tf.data.AUTOTUNE)

    return ds
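
# End-to-end usage sketch (hypothetical `features`/`labels` tensors; any
# finite, labeled (x, y) dataset works). Each batch holds
# classes_per_batch * examples_per_class_per_batch elements:
#
#     ds = tf.data.Dataset.from_tensor_slices((features, labels))
#     sampler = TFDataSampler(ds, classes_per_batch=4, examples_per_class_per_batch=8)
#     x_batch, y_batch = next(iter(sampler))  # 32 examples, 8 per class.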
180 changes: 180 additions & 0 deletions tests/samplers/test_tfdata_sampler.py
@@ -0,0 +1,180 @@
from collections import defaultdict

import tensorflow as tf

from tensorflow_similarity.samplers import tfdata_sampler as tfds


class TestCreateGroupedDataset(tf.test.TestCase):
    def setUp(self):
        self.ds = tf.data.Dataset.from_tensor_slices(
            (
                tf.constant([1, 2, 3, 4, 5, 6, 7, 8], dtype=tf.float32),
                tf.constant([1, 1, 2, 2, 3, 3, 4, 4], dtype=tf.int64),
            )
        )

    def test_returns_correct_number_of_datasets(self):
        cid_datasets = tfds.create_grouped_dataset(self.ds)
        self.assertLen(cid_datasets, 4)

    def test_returns_correct_number_of_datasets_with_class_list(self):
        cid_datasets = tfds.create_grouped_dataset(self.ds, class_list=[1, 2])
        self.assertLen(cid_datasets, 2)

    def test_returns_correct_number_of_datasets_with_total_examples(self):
        cid_datasets = tfds.create_grouped_dataset(self.ds, total_examples=1)
        self.assertLen(cid_datasets, 4)

        # Test that each cid dataset has only 1 example that is repeated.
        for elem in cid_datasets:
            elements = list(elem.take(2).as_numpy_iterator())
            self.assertAllEqual(elements[0], elements[1])

    def test_returns_correct_number_of_datasets_with_buffer_size(self):
        cid_datasets = tfds.create_grouped_dataset(self.ds, buffer_size=2)
        self.assertLen(cid_datasets, 4)

    def test_datasets_repeat(self):
        cid_datasets = tfds.create_grouped_dataset(self.ds)
        for cid_ds in cid_datasets:
            self.assertEqual(cid_ds.element_spec, self.ds.element_spec)

            # Check that repeating groups of 2 elements all equal each other.
            elements = list(cid_ds.take(6).as_numpy_iterator())
            self.assertAllEqual(elements[:2], elements[2:4])
            self.assertAllEqual(elements[:2], elements[4:])

    def test_datasets_shuffled(self):
        cid_datasets = tfds.create_grouped_dataset(self.ds, buffer_size=4)
        # Check that including the buffer shuffles the values in each cid_ds.
        for cid_ds in cid_datasets:
            self.assertNotEqual(
                list(cid_ds.take(20).as_numpy_iterator()),
                list(cid_ds.take(20).as_numpy_iterator()),
            )


class TestCreateChoicesDataset(tf.test.TestCase):
    def test_sample_without_replacement(self):
        # Test that every class appears within the first
        # num_classes * examples_per_class elements.
        num_classes = 5
        examples_per_class = 2
        dataset = tfds.create_choices_dataset(num_classes, examples_per_class)
        elements = list(dataset.take(num_classes * examples_per_class).as_numpy_iterator())
        unique_elements = set(elements)
        self.assertLen(unique_elements, num_classes)

    def test_dataset_values(self):
        # Test that the dataset only contains values in [0, num_classes).
        num_classes = 10
        examples_per_class = 3
        dataset = tfds.create_choices_dataset(num_classes, examples_per_class)
        for x in dataset.take(num_classes * examples_per_class).as_numpy_iterator():
            self.assertGreaterEqual(x, 0)
            self.assertLess(x, num_classes)

    def test_dataset_repetition(self):
        # Test that each class appears exactly examples_per_class * num_repeats times.
        num_classes = 4
        examples_per_class = 2
        num_repeats = 2
        dataset = tfds.create_choices_dataset(num_classes, examples_per_class)
        class_counts = defaultdict(int)
        for x in dataset.take(num_classes * examples_per_class * num_repeats).as_numpy_iterator():
            class_counts[x] += 1
        for count in class_counts.values():
            self.assertEqual(count, examples_per_class * num_repeats)


def dummy_augmenter(x):
    return x + 10


class TestAugmenter(tf.test.TestCase):
    def setUp(self):
        self.ds = tf.data.Dataset.range(10).batch(2)

    def test_apply_augmentation_no_warmup(self):
        augmented_ds = tfds.apply_augmenter_ds(self.ds, dummy_augmenter)

        for x in augmented_ds:
            self.assertListEqual(x.numpy().tolist(), [10, 11])
            break

    def test_apply_augmentation_with_warmup(self):
        warmup = 1
        augmented_ds = tfds.apply_augmenter_ds(self.ds, dummy_augmenter, warmup)

        for i, x in enumerate(augmented_ds):
            if i < warmup:
                self.assertListEqual(x.numpy().tolist(), [0, 1])
            else:
                self.assertListEqual(x.numpy().tolist(), [12, 13])
                break


class TestTFDataSampler(tf.test.TestCase):
    def setUp(self):
        self.ds = tf.data.Dataset.from_tensor_slices(
            (
                tf.random.uniform((6, 2), dtype=tf.float32),
                tf.constant([1, 1, 1, 2, 2, 2], dtype=tf.int64),
            )
        )
        self.expected_elementspec = (
            tf.TensorSpec(shape=(None, 2), dtype=tf.float32, name=None),
            tf.TensorSpec(shape=(None,), dtype=tf.int64, name=None),
        )

    def test_cardinality_is_finite(self):
        # Test that an exception is raised when the input dataset is infinite.
        ds = tf.data.Dataset.from_tensors([1]).repeat()
        with self.assertRaisesWithLiteralMatch(ValueError, "Dataset must be finite."):
            tfds.TFDataSampler(ds)

    def test_cardinality_is_known(self):
        # Test that an exception is raised when the input dataset has unknown cardinality.
        ds = tf.data.Dataset.from_tensor_slices([1, 2, 3]).shuffle(3).filter(lambda x: x > 1)
        with self.assertRaisesWithLiteralMatch(ValueError, "Dataset cardinality must be known."):
            tfds.TFDataSampler(ds)

    def test_output_batch_size(self):
        # Test that the output element spec is correct with the default batch size.
        out_ds = tfds.TFDataSampler(self.ds)
        self.assertEqual(out_ds.element_spec, self.expected_elementspec)

    def test_output_classes_per_batch(self):
        # Test that the output element spec is correct with one class per batch.
        out_ds = tfds.TFDataSampler(self.ds, classes_per_batch=1)
        self.assertEqual(out_ds.element_spec, self.expected_elementspec)

    def test_output_examples_per_class_per_batch(self):
        # Test that the output element spec is correct with one example per class per batch.
        out_ds = tfds.TFDataSampler(self.ds, examples_per_class_per_batch=1)
        self.assertEqual(out_ds.element_spec, self.expected_elementspec)

    def test_output_class_list(self):
        # Test that the class list is correctly used.
        out_ds = tfds.TFDataSampler(self.ds, class_list=[1])
        self.assertEqual(out_ds.element_spec, self.expected_elementspec)

    def test_output_total_examples_per_class(self):
        # Test that the total number of examples per class is correctly used.
        out_ds = tfds.TFDataSampler(self.ds, total_examples_per_class=2)
        self.assertEqual(out_ds.element_spec, self.expected_elementspec)

    def test_output_augmenter(self):
        # Test that the augmenter is correctly applied.
        out_ds = tfds.TFDataSampler(self.ds, augmenter=lambda x, y: (x * 2, y))
        self.assertEqual(out_ds.element_spec, self.expected_elementspec)

    def test_output_load_fn(self):
        # TODO(ovallis): Test that the load_fn is correctly applied.
        pass
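        # A possible sketch once implemented (hypothetical load_fn that maps
        # (x, y) -> (loaded_x, y); the identity-plus-offset stands in for a
        # disk read):
        #
        #     out_ds = tfds.TFDataSampler(self.ds, load_fn=lambda x, y: (x + 1.0, y))
        #     self.assertEqual(out_ds.element_spec, self.expected_elementspec)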


if __name__ == "__main__":
    tf.test.main()