Skip to content

Add GTSRB dataset to prototypes #5214

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 44 commits into from
Jan 24, 2022
Merged
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
237a707
Change default of download for Food101 and DTD
NicolasHug Jan 5, 2022
bc3be4e
WIP
NicolasHug Jan 7, 2022
85ca229
Merge branch 'main' of github.com:pytorch/vision into defaultdownload
NicolasHug Jan 18, 2022
87695d4
Set download default to False and put it at the end
NicolasHug Jan 18, 2022
1e6e37d
Keep stuff private
NicolasHug Jan 18, 2022
474546f
GTSRB: train -> split. Also use pathlib
NicolasHug Jan 18, 2022
a38a18b
mypy
NicolasHug Jan 18, 2022
d58ef16
Remove split and partition for SUN397
NicolasHug Jan 18, 2022
5061141
mypy
NicolasHug Jan 18, 2022
6c02cff
mypy
NicolasHug Jan 18, 2022
d3cb34f
Merge branch 'main' of github.com:pytorch/vision into gtsrb_prototype
NicolasHug Jan 18, 2022
d288c6c
Merge branch 'defaultdownload' into gtsrb_prototype
NicolasHug Jan 18, 2022
521b75c
WIP
NicolasHug Jan 18, 2022
1c1ceb0
WIP
NicolasHug Jan 19, 2022
1b2ee27
Merge branch 'main' of github.com:pytorch/vision into gtsrb_prototype
NicolasHug Jan 19, 2022
4fdb976
WIP
NicolasHug Jan 19, 2022
a6ae4c4
Add tests
NicolasHug Jan 19, 2022
761e5d7
Add some types
NicolasHug Jan 19, 2022
1dd6efe
lmao mypy you funny lad
NicolasHug Jan 19, 2022
a32ab88
fix unpacking
NicolasHug Jan 19, 2022
862187a
Merge branch 'main' of github.com:pytorch/vision into gtsrb_prototype
NicolasHug Jan 19, 2022
e487828
Use DictWriter
NicolasHug Jan 19, 2022
8f15cc3
Hardcode categories since they are just ints in [0, 42]
NicolasHug Jan 19, 2022
9ac22d3
Split URL root
NicolasHug Jan 19, 2022
1f1fa35
Use name instead of stem
NicolasHug Jan 19, 2022
f25a83a
Add category to labels, and fix dict reading
NicolasHug Jan 19, 2022
52ec648
Use path_comparator
NicolasHug Jan 19, 2022
379876f
Use buffer_size=1
NicolasHug Jan 19, 2022
632c212
Merge branch 'main' of github.com:pytorch/vision into gtsrb_prototype
NicolasHug Jan 20, 2022
0d6b58d
Merge branch 'main' of github.com:pytorch/vision into gtsrb_prototype
NicolasHug Jan 20, 2022
e26b456
Use Zipper instead of IterKeyZipper
NicolasHug Jan 20, 2022
b958b6b
mypy
NicolasHug Jan 20, 2022
06c0904
Some more instructions
NicolasHug Jan 20, 2022
18b87e2
forgot backquotes
NicolasHug Jan 20, 2022
44bb8f1
Apply suggestions from code review
NicolasHug Jan 21, 2022
c1ec16d
gt -> ground_truth
NicolasHug Jan 21, 2022
ff78c70
e -> sample
NicolasHug Jan 21, 2022
cd38e25
Add support for bboxes
NicolasHug Jan 21, 2022
1e8aea6
Update torchvision/prototype/datasets/_builtin/gtsrb.py
NicolasHug Jan 21, 2022
8e9a617
format
NicolasHug Jan 21, 2022
6703710
Remove unused method
NicolasHug Jan 21, 2022
6b67ce7
Add test for label matching
NicolasHug Jan 21, 2022
1ef84e0
Update test/test_prototype_builtin_datasets.py
NicolasHug Jan 24, 2022
8283332
Merge branch 'main' into gtsrb_prototype
NicolasHug Jan 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions test/builtin_dataset_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,76 @@ def fer2013(info, root, config):
return num_samples


@DATASET_MOCKS.set_from_named_callable
def gtsrb(info, root, config):
num_examples_per_class = 5 if config.split == "train" else 3
classes = ("00000", "00042", "00012")
num_examples = num_examples_per_class * len(classes)

csv_columns = ["Filename", "Width", "Height", "Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2", "ClassId"]

def _make_ann_file(path, num_examples, class_idx):
if class_idx == "random":
class_idx = torch.randint(1, len(classes) + 1, size=(1,)).item()

with open(path, "w") as csv_file:
writer = csv.DictWriter(csv_file, fieldnames=csv_columns, delimiter=";")
writer.writeheader()
for image_idx in range(num_examples):
writer.writerow(
{
"Filename": f"{image_idx:05d}.ppm",
"Width": torch.randint(1, 100, size=()).item(),
"Height": torch.randint(1, 100, size=()).item(),
"Roi.X1": torch.randint(1, 100, size=()).item(),
"Roi.Y1": torch.randint(1, 100, size=()).item(),
"Roi.X2": torch.randint(1, 100, size=()).item(),
"Roi.Y2": torch.randint(1, 100, size=()).item(),
"ClassId": class_idx,
}
)

if config["split"] == "train":
train_folder = root / "GTSRB" / "Training"
train_folder.mkdir(parents=True)

for class_idx in classes:
create_image_folder(
train_folder,
name=class_idx,
file_name_fn=lambda image_idx: f"{class_idx}_{image_idx:05d}.ppm",
num_examples=num_examples_per_class,
)
_make_ann_file(
path=train_folder / class_idx / f"GT-{class_idx}.csv",
num_examples=num_examples_per_class,
class_idx=int(class_idx),
)
make_zip(root, "GTSRB-Training_fixed.zip", train_folder)
else:
test_folder = root / "GTSRB" / "Final_Test"
test_folder.mkdir(parents=True)

create_image_folder(
test_folder,
name="Images",
file_name_fn=lambda image_idx: f"{image_idx:05d}.ppm",
num_examples=num_examples,
)

make_zip(root, "GTSRB_Final_Test_Images.zip", test_folder)

_make_ann_file(
path=root / "GT-final_test.csv",
num_examples=num_examples,
class_idx="random",
)

make_zip(root, "GTSRB_Final_Test_GT.zip", "GT-final_test.csv")

return num_examples


@DATASET_MOCKS.set_from_named_callable
def clevr(info, root, config):
data_folder = root / "CLEVR_v1.0"
Expand Down
2 changes: 1 addition & 1 deletion test/test_prototype_builtin_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def scan(graph):
if type(dp) is annotation_dp_type:
break
else:
raise AssertionError(f"The dataset doesn't comprise a {annotation_dp_type.__name__}() datapipe.")
raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")


@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
Expand Down
166 changes: 134 additions & 32 deletions torchvision/prototype/datasets/_builtin/README.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
# How to add new built-in prototype datasets

As the name implies, the datasets are still in a prototype state and thus subject to rapid change. This in turn means that this document will also change a lot.
As the name implies, the datasets are still in a prototype state and thus
subject to rapid change. This in turn means that this document will also change
a lot.

If you hit a blocker while adding a dataset, please have a look at another similar dataset to see how it is implemented there. If you can't resolve it yourself, feel free to send a draft PR in order for us to help you out.
If you hit a blocker while adding a dataset, please have a look at another
similar dataset to see how it is implemented there. If you can't resolve it
yourself, feel free to send a draft PR in order for us to help you out.

Finally, `from torchvision.prototype import datasets` is implied below.

## Implementation

Before we start with the actual implementation, you should create a module in `torchvision/prototype/datasets/_builtin` that hints at the dataset you are going to add. For example `caltech.py` for `caltech101` and `caltech256`. In that module create a class that inherits from `datasets.utils.Dataset` and overwrites at minimum three methods that will be discussed in detail below:
Before we start with the actual implementation, you should create a module in
`torchvision/prototype/datasets/_builtin` that hints at the dataset you are
going to add. For example `caltech.py` for `caltech101` and `caltech256`. In
that module create a class that inherits from `datasets.utils.Dataset` and
overwrites at minimum three methods that will be discussed in detail below:

```python
import io
Expand Down Expand Up @@ -37,27 +45,54 @@ class MyDataset(Dataset):

### `_make_info(self)`

The `DatasetInfo` carries static information about the dataset. There are two required fields:
- `name`: Name of the dataset. This will be used to load the dataset with `datasets.load(name)`. Should only contain lower characters.
- `type`: Field of the `datasets.utils.DatasetType` enum. This is used to select the default decoder in case the user doesn't pass one. There are currently only two options: `IMAGE` and `RAW` ([see below](what-is-the-datasettyperaw-and-when-do-i-use-it) for details).
The `DatasetInfo` carries static information about the dataset. There are two
required fields:
- `name`: Name of the dataset. This will be used to load the dataset with
`datasets.load(name)`. Should only contain lowercase characters.
- `type`: Field of the `datasets.utils.DatasetType` enum. This is used to select
the default decoder in case the user doesn't pass one. There are currently
only two options: `IMAGE` and `RAW` ([see
below](what-is-the-datasettyperaw-and-when-do-i-use-it) for details).

There are more optional parameters that can be passed:

- `dependencies`: Collection of third-party dependencies that are needed to load the dataset, e.g. `("scipy",)`. Their availability will be automatically checked if a user tries to load the dataset. Within the implementation, import these packages lazily to avoid missing dependencies at import time.
- `categories`: Sequence of human-readable category names for each label. The index of each category has to match the corresponding label returned in the dataset samples. [See below](#how-do-i-handle-a-dataset-that-defines-many-categories) how to handle cases with many categories.
- `valid_options`: Configures valid options that can be passed to the dataset. It should be `Dict[str, Sequence[str]]`. The options are accessible through the `config` namespace in the other two functions. First value of the sequence is taken as default if the user passes no option to `torchvision.prototype.datasets.load()`.
- `dependencies`: Collection of third-party dependencies that are needed to load
the dataset, e.g. `("scipy",)`. Their availability will be automatically
checked if a user tries to load the dataset. Within the implementation, import
these packages lazily to avoid missing dependencies at import time.
- `categories`: Sequence of human-readable category names for each label. The
index of each category has to match the corresponding label returned in the
dataset samples. [See
below](#how-do-i-handle-a-dataset-that-defines-many-categories) how to handle
cases with many categories.
- `valid_options`: Configures valid options that can be passed to the dataset.
It should be `Dict[str, Sequence[Any]]`. The options are accessible through
the `config` namespace in the other two functions. First value of the sequence
is taken as default if the user passes no option to
`torchvision.prototype.datasets.load()`.

## `resources(self, config)`

Returns `List[datasets.utils.OnlineResource]` of all the files that need to be present locally before the dataset with a specific `config` can be build. The download will happen automatically.
Returns `List[datasets.utils.OnlineResource]` of all the files that need to be
present locally before the dataset with a specific `config` can be build. The
download will happen automatically.

Currently, the following `OnlineResource`'s are supported:

- `HttpResource`: Used for files that are directly exposed through HTTP(s) and only requires the URL.
- `GDriveResource`: Used for files that are hosted on GDrive and requires the GDrive ID as well as the `file_name`.
- `ManualDownloadResource`: Used files are not publicly accessible and requires instructions how to download them manually. If the file does not exist, an error will be raised with the supplied instructions.

Although optional in general, all resources used in the built-in datasets should comprise [SHA256](https://en.wikipedia.org/wiki/SHA-2) checksum for security. It will be automatically checked after the download. You can compute the checksum with system utilities or this snippet:
- `HttpResource`: Used for files that are directly exposed through HTTP(s) and
only requires the URL.
- `GDriveResource`: Used for files that are hosted on GDrive and requires the
GDrive ID as well as the `file_name`.
- `ManualDownloadResource`: Used files are not publicly accessible and requires
instructions how to download them manually. If the file does not exist, an
error will be raised with the supplied instructions.
- `KaggleDownloadResource`: Used for files that are available on Kaggle. This
inherits from `ManualDownloadResource`.

Although optional in general, all resources used in the built-in datasets should
comprise [SHA256](https://en.wikipedia.org/wiki/SHA-2) checksum for security. It
will be automatically checked after the download. You can compute the checksum
with system utilities e.g `sha256-sum`, or this snippet:

```python
import hashlib
Expand All @@ -72,35 +107,84 @@ def sha256sum(path, chunk_size=1024 * 1024):

### `_make_datapipe(resource_dps, *, config, decoder)`

This method is the heart of the dataset that need to transform the raw data into a usable form. A major difference compared to the current stable datasets is that everything is performed through `IterDataPipe`'s. From the perspective of someone that is working with them rather than on them, `IterDataPipe`'s behave just as generators, i.e. you can't do anything with them besides iterating.
This method is the heart of the dataset, where we transform the raw data into
a usable form. A major difference compared to the current stable datasets is
that everything is performed through `IterDataPipe`'s. From the perspective of
someone that is working with them rather than on them, `IterDataPipe`'s behave
just as generators, i.e. you can't do anything with them besides iterating.

Of course, there are some common building blocks that should suffice in 95% of the cases. The most used
Of course, there are some common building blocks that should suffice in 95% of
the cases. The most used are:

- `Mapper`: Apply a callable to every item in the datapipe.
- `Filter`: Keep only items that satisfy a condition.
- `Demultiplexer`: Split a datapipe into multiple ones.
- `IterKeyZipper`: Merge two datapipes into one.

All of them can be imported `from torchdata.datapipes.iter`. In addition, use `functools.partial` in case a callable needs extra arguments. If the provided `IterDataPipe`'s are not sufficient for the use case, it is also not complicated to add one. See the MNIST or CelebA datasets for example.
All of them can be imported `from torchdata.datapipes.iter`. In addition, use
`functools.partial` in case a callable needs extra arguments. If the provided
`IterDataPipe`'s are not sufficient for the use case, it is also not complicated
to add one. See the MNIST or CelebA datasets for example.

`make_datapipe()` receives `resource_dps`, which is a list of datapipes that has
a 1-to-1 correspondence with the return value of `resources()`. In case of
archives with regular suffixes (`.tar`, `.zip`, ...), the datapipe will contain
tuples comprised of the path and the handle for every file in the archive.
Otherwise the datapipe will only contain one of such tuples for the file
specified by the resource.

Since the datapipes are iterable in nature, some datapipes feature an in-memory
buffer, e.g. `IterKeyZipper` and `Grouper`. There are two issues with that: 1.
If not used carefully, this can easily overflow the host memory, since most
datasets will not fit in completely. 2. This can lead to unnecessarily long
warm-up times when data is buffered that is only needed at runtime.

Thus, all buffered datapipes should be used as early as possible, e.g. zipping
two datapipes of file handles rather than trying to zip already loaded images.

There are two special datapipes that are not used through their class, but
through the functions `hint_sharding` and `hint_shuffling`. As the name implies
they only hint part in the datapipe graph where sharding and shuffling should
take place, but are no-ops by default. They can be imported from
`torchvision.prototype.datasets.utils._internal` and are required in each
dataset.

Finally, each item in the final datapipe should be a dictionary with `str` keys.
There is no standardization of the names (yet!).

`make_datapipe()` receives `resource_dps`, which is a list of datapipes that has a 1-to-1 correspondence with the return value of `resources()`. In case of archives with regular suffixes (`.tar`, `.zip`, ...), the datapipe will contain tuples comprised of the path and the handle for every file in the archive. Otherwise the datapipe will only contain one of such tuples for the file specified by the resource.
## FAQ

Since the datapipes are iterable in nature, some datapipes feature an in-memory buffer, e.g. `IterKeyZipper` and `Grouper`. There are two issues with that:
1. If not used carefully, this can easily overflow the host memory, since most datasets will not fit in completely.
2. This can lead to unnecessarily long warm-up times when data is buffered that is only needed at runtime.
### How do I start?

Thus, all buffered datapipes should be used as early as possible, e.g. zipping two datapipes of file handles rather than trying to zip already loaded images.
Get the skeleton of your dataset class ready with all 3 methods. For
`_make_datapipe()`, you can just do `return resources_dp[0]` to get started.
Then import the dataset class in
`torchvision/prototype/datasets/_builtin/__init__.py`: this will automatically
register the dataset and it will be instantiable via
`datasets.load("mydataset")`. On a separate script, try something like

There are two special datapipes that are not used through their class, but through the functions `hint_sharding` and `hint_shuffling`. As the name implies they only hint part in the datapipe graph where sharding and shuffling should take place, but are no-ops by default. They can be imported from `torchvision.prototype.datasets.utils._internal` and are required in each dataset.
```py
from torchvision.prototype import datasets

Finally, each item in the final datapipe should be a dictionary with `str` keys. There is no standardization of the names (yet!).
dataset = datasets.load("mydataset")
for sample in dataset:
print(sample) # this is the content of an item in datapipe returned by _make_datapipe()
break
# Or you can also inspect the sample in a debugger
```

## FAQ
This will give you an idea of what the first datapipe in `resources_dp`
contains. You can also do that with `resources_dp[1]` or `resources_dp[2]`
(etc.) if they exist. Then follow the instructions above to manipulate these
datapipes and return the appropriate dictionary format.

### What is the `DatasetType.RAW` and when do I use it?

`DatasetType.RAW` marks dataset that provides decoded, i.e. raw pixel values, rather than encoded image files such as
`.jpg` or `.png`. This is usually only the case for small datasets, since it requires a lot more disk space. The default decoder `datasets.decoder.raw` is only a sentinel and should not be called directly. The decoding should look something like
`DatasetType.RAW` marks dataset that provides decoded, i.e. raw pixel values,
rather than encoded image files such as `.jpg` or `.png`. This is usually only
the case for small datasets, since it requires a lot more disk space. The
default decoder `datasets.decoder.raw` is only a sentinel and should not be
called directly. The decoding should look something like

```python
from torchvision.prototype.datasets.decoder import raw
Expand All @@ -118,10 +202,28 @@ For examples, have a look at the MNIST, CIFAR, or SEMEION datasets.

### How do I handle a dataset that defines many categories?

As a rule of thumb, `datasets.utils.DatasetInfo(..., categories=)` should only be set directly for ten categories or fewer. If more categories are needed, you can add a `$NAME.categories` file to the `_builtin` folder in which each line specifies a category. If `$NAME` matches the name of the dataset (which it definitively should!) it will be automatically loaded if `categories=` is not set.

In case the categories can be generated from the dataset files, e.g. the dataset follow an image folder approach where each folder denotes the name of the category, the dataset can overwrite the `_generate_categories` method. It gets passed the `root` path to the resources, but they have to be manually loaded, e.g. `self.resources(config)[0].load(root)`. The method should return a sequence of strings representing the category names. To generate the `$NAME.categories` file, run `python -m torchvision.prototype.datasets.generate_category_files $NAME`.
As a rule of thumb, `datasets.utils.DatasetInfo(..., categories=)` should only
be set directly for ten categories or fewer. If more categories are needed, you
can add a `$NAME.categories` file to the `_builtin` folder in which each line
specifies a category. If `$NAME` matches the name of the dataset (which it
definitively should!) it will be automatically loaded if `categories=` is not
set.

In case the categories can be generated from the dataset files, e.g. the dataset
follows an image folder approach where each folder denotes the name of the
category, the dataset can overwrite the `_generate_categories` method. It gets
passed the `root` path to the resources, but they have to be manually loaded,
e.g. `self.resources(config)[0].load(root)`. The method should return a sequence
of strings representing the category names. To generate the `$NAME.categories`
file, run `python -m torchvision.prototype.datasets.generate_category_files
$NAME`.

### What if a resource file forms an I/O bottleneck?

In general, we are ok with small performance hits of iterating archives rather than their extracted content. However, if the performance hit becomes significant, the archives can still be decompressed or extracted. To do this, the `decompress: bool` and `extract: bool` flags can be used for every `OnlineResource` individually. For more complex cases, each resource also accepts a `preprocess` callable that gets passed a `pathlib.Path` of the raw file and should return `pathlib.Path` of the preprocessed file or folder.
In general, we are ok with small performance hits of iterating archives rather
than their extracted content. However, if the performance hit becomes
significant, the archives can still be decompressed or extracted. To do this,
the `decompress: bool` and `extract: bool` flags can be used for every
`OnlineResource` individually. For more complex cases, each resource also
accepts a `preprocess` callable that gets passed a `pathlib.Path` of the raw
file and should return `pathlib.Path` of the preprocessed file or folder.
1 change: 1 addition & 0 deletions torchvision/prototype/datasets/_builtin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .cub200 import CUB200
from .dtd import DTD
from .fer2013 import FER2013
from .gtsrb import GTSRB
from .imagenet import ImageNet
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from .oxford_iiit_pet import OxfordIITPet
Expand Down
Loading