-
Notifications
You must be signed in to change notification settings - Fork 7.1k
migrate mnist prototype datasets #5480
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -214,70 +214,76 @@ def generate( | |
return num_samples | ||
|
||
|
||
# @register_mock | ||
def mnist(info, root, config): | ||
train = config.split == "train" | ||
images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz" | ||
labels_file = f"{'train' if train else 't10k'}-labels-idx1-ubyte.gz" | ||
def mnist(root, config): | ||
prefix = "train" if config["split"] == "train" else "t10k" | ||
return MNISTMockData.generate( | ||
root, | ||
num_categories=len(info.categories), | ||
images_file=images_file, | ||
labels_file=labels_file, | ||
num_categories=10, | ||
images_file=f"{prefix}-images-idx3-ubyte.gz", | ||
labels_file=f"{prefix}-labels-idx1-ubyte.gz", | ||
) | ||
|
||
|
||
# DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]}) | ||
DATASET_MOCKS.update( | ||
{ | ||
name: DatasetMock(name, mock_data_fn=mnist, configs=combinations_grid(split=("train", "test"))) | ||
for name in ["mnist", "fashionmnist", "kmnist"] | ||
} | ||
) | ||
|
||
|
||
# @register_mock | ||
def emnist(info, root, config): | ||
# The image sets that merge some lower case letters in their respective upper case variant, still use dense | ||
# labels in the data files. Thus, num_categories != len(categories) there. | ||
num_categories = defaultdict( | ||
lambda: len(info.categories), {image_set: 47 for image_set in ("Balanced", "By_Merge")} | ||
@register_mock( | ||
configs=combinations_grid( | ||
split=("train", "test"), | ||
image_set=("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST"), | ||
) | ||
|
||
) | ||
def emnist(root, config): | ||
num_samples_map = {} | ||
file_names = set() | ||
for config_ in info._configs: | ||
prefix = f"emnist-{config_.image_set.replace('_', '').lower()}-{config_.split}" | ||
for split, image_set in itertools.product( | ||
("train", "test"), | ||
("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST"), | ||
): | ||
Comment on lines
+244
to
+247
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No strong opinion but what was the reason for changing this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We no longer have access to a |
||
prefix = f"emnist-{image_set.replace('_', '').lower()}-{split}" | ||
images_file = f"{prefix}-images-idx3-ubyte.gz" | ||
labels_file = f"{prefix}-labels-idx1-ubyte.gz" | ||
file_names.update({images_file, labels_file}) | ||
num_samples_map[config_] = MNISTMockData.generate( | ||
num_samples_map[(split, image_set)] = MNISTMockData.generate( | ||
root, | ||
num_categories=num_categories[config_.image_set], | ||
# The image sets that merge some lower case letters in their respective upper case variant, still use dense | ||
# labels in the data files. Thus, num_categories != len(categories) there. | ||
num_categories=47 if config["image_set"] in ("Balanced", "By_Merge") else 62, | ||
images_file=images_file, | ||
labels_file=labels_file, | ||
) | ||
|
||
make_zip(root, "emnist-gzip.zip", *file_names) | ||
|
||
return num_samples_map[config] | ||
return num_samples_map[(config["split"], config["image_set"])] | ||
|
||
|
||
# @register_mock | ||
def qmnist(info, root, config): | ||
num_categories = len(info.categories) | ||
if config.split == "train": | ||
@register_mock(configs=combinations_grid(split=("train", "test", "test10k", "test50k", "nist"))) | ||
def qmnist(root, config): | ||
num_categories = 10 | ||
if config["split"] == "train": | ||
num_samples = num_samples_gen = num_categories + 2 | ||
prefix = "qmnist-train" | ||
suffix = ".gz" | ||
compressor = gzip.open | ||
elif config.split.startswith("test"): | ||
elif config["split"].startswith("test"): | ||
# The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create | ||
# more than 10000 images for the dataset to not be empty. | ||
num_samples_gen = 10001 | ||
num_samples = { | ||
"test": num_samples_gen, | ||
"test10k": min(num_samples_gen, 10_000), | ||
"test50k": num_samples_gen - 10_000, | ||
}[config.split] | ||
}[config["split"]] | ||
prefix = "qmnist-test" | ||
suffix = ".gz" | ||
compressor = gzip.open | ||
else: # config.split == "nist" | ||
else: # config["split"] == "nist" | ||
num_samples = num_samples_gen = num_categories + 3 | ||
prefix = "xnist" | ||
suffix = ".xz" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC the only thing that changes is the
name
here. Is this going to have an effect on the actual tests that are being run, or are we running the same test 3 times? If the latter, perhaps we could register the callable with the regularregsiter_mock
decorator, and just add as a comment that this also applies to fashionmnist and kmnist?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct.
We are running tests for all three datasets with the "same" mock data. If we didn't, our tests wouldn't cover the other datasets. Admittedly, the current implementation is just FashionMNIST and
KMNIST
subclassingMNIST
and minimal parametrization, but I don't think we take much of a duration hit to include these. If that is a concern, there are other datasets that we should optimize first.