Skip to content

migrate mnist prototype datasets #5480

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 34 additions & 28 deletions test/builtin_dataset_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,70 +214,76 @@ def generate(
return num_samples


# @register_mock
def mnist(info, root, config):
train = config.split == "train"
images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz"
labels_file = f"{'train' if train else 't10k'}-labels-idx1-ubyte.gz"
def mnist(root, config):
prefix = "train" if config["split"] == "train" else "t10k"
return MNISTMockData.generate(
root,
num_categories=len(info.categories),
images_file=images_file,
labels_file=labels_file,
num_categories=10,
images_file=f"{prefix}-images-idx3-ubyte.gz",
labels_file=f"{prefix}-labels-idx1-ubyte.gz",
)


# DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]})
DATASET_MOCKS.update(
{
name: DatasetMock(name, mock_data_fn=mnist, configs=combinations_grid(split=("train", "test")))
for name in ["mnist", "fashionmnist", "kmnist"]
}
)
Comment on lines +227 to +232
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC the only thing that changes is the name here. Is this going to have an effect on the actual tests that are being run, or are we running the same test 3 times? If the latter, perhaps we could register the callable with the regular register_mock decorator, and just add as a comment that this also applies to fashionmnist and kmnist?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> IIUC the only thing that changes is the name here.

Correct.

> Is this going to have an effect on the actual tests that are being run, or are we running the same test 3 times?

We are running tests for all three datasets with the "same" mock data. If we didn't, our tests wouldn't cover the other datasets. Admittedly, in the current implementation FashionMNIST and KMNIST merely subclass MNIST with minimal parametrization, but I don't think including them adds much to the test duration. If that is a concern, there are other datasets that we should optimize first.



# @register_mock
def emnist(info, root, config):
# The image sets that merge some lower case letters in their respective upper case variant, still use dense
# labels in the data files. Thus, num_categories != len(categories) there.
num_categories = defaultdict(
lambda: len(info.categories), {image_set: 47 for image_set in ("Balanced", "By_Merge")}
@register_mock(
configs=combinations_grid(
split=("train", "test"),
image_set=("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST"),
)

)
def emnist(root, config):
num_samples_map = {}
file_names = set()
for config_ in info._configs:
prefix = f"emnist-{config_.image_set.replace('_', '').lower()}-{config_.split}"
for split, image_set in itertools.product(
("train", "test"),
("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST"),
):
Comment on lines +244 to +247
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No strong opinion but what was the reason for changing this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We no longer have access to a DatasetInfo that holds all the ._configs. Thus, we need to create them manually.

prefix = f"emnist-{image_set.replace('_', '').lower()}-{split}"
images_file = f"{prefix}-images-idx3-ubyte.gz"
labels_file = f"{prefix}-labels-idx1-ubyte.gz"
file_names.update({images_file, labels_file})
num_samples_map[config_] = MNISTMockData.generate(
num_samples_map[(split, image_set)] = MNISTMockData.generate(
root,
num_categories=num_categories[config_.image_set],
# The image sets that merge some lower case letters in their respective upper case variant, still use dense
# labels in the data files. Thus, num_categories != len(categories) there.
num_categories=47 if config["image_set"] in ("Balanced", "By_Merge") else 62,
images_file=images_file,
labels_file=labels_file,
)

make_zip(root, "emnist-gzip.zip", *file_names)

return num_samples_map[config]
return num_samples_map[(config["split"], config["image_set"])]


# @register_mock
def qmnist(info, root, config):
num_categories = len(info.categories)
if config.split == "train":
@register_mock(configs=combinations_grid(split=("train", "test", "test10k", "test50k", "nist")))
def qmnist(root, config):
num_categories = 10
if config["split"] == "train":
num_samples = num_samples_gen = num_categories + 2
prefix = "qmnist-train"
suffix = ".gz"
compressor = gzip.open
elif config.split.startswith("test"):
elif config["split"].startswith("test"):
# The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create
# more than 10000 images for the dataset to not be empty.
num_samples_gen = 10001
num_samples = {
"test": num_samples_gen,
"test10k": min(num_samples_gen, 10_000),
"test50k": num_samples_gen - 10_000,
}[config.split]
}[config["split"]]
prefix = "qmnist-test"
suffix = ".gz"
compressor = gzip.open
else: # config.split == "nist"
else: # config["split"] == "nist"
num_samples = num_samples_gen = num_categories + 3
prefix = "xnist"
suffix = ".xz"
Expand Down
3 changes: 1 addition & 2 deletions test/test_prototype_builtin_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,7 @@ def test_infinite_buffer_size(self, test_home, dataset_mock, config):
assert dp.buffer_size == INFINITE_BUFFER_SIZE


# FIXME: DATASET_MOCKS["qmnist"]
@parametrize_dataset_mocks({})
@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
class TestQMNIST:
def test_extra_label(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
Expand Down
Loading