Migrate Fer2013 prototype dataset (#5759)
* Migrate Fer2013 prototype dataset

* Update torchvision/prototype/datasets/_builtin/fer2013.py

Co-authored-by: Philip Meier <github.pmeier@posteo.de>
NicolasHug and pmeier authored Apr 6, 2022
1 parent 6de6ec4 commit ebe9006
Showing 2 changed files with 41 additions and 29 deletions.
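For context, a minimal sketch of how the migrated dataset is meant to be used after this change. The root path below is a placeholder, and loading assumes the Kaggle CSV archives are already available there (or can be fetched by KaggleDownloadResource); with the new register_dataset/register_info decorators the dataset is presumably also reachable through the prototype datasets registry, but only direct construction is shown here.

    # Sketch only: direct construction of the migrated dataset.
    from torchvision.prototype.datasets._builtin.fer2013 import FER2013

    dataset = FER2013("path/to/fer2013", split="train")  # "path/to/fer2013" is a placeholder root
    print(len(dataset))  # 28_709 for the train split, per __len__ below

    sample = next(iter(dataset))  # the prototype datasets are iterable, so grab one sample
    print(sample["image"].shape)  # one 48x48 grayscale face crop
    print(sample["label"])        # Label drawn from the 7 emotion categories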
13 changes: 7 additions & 6 deletions test/builtin_dataset_mocks.py
@@ -1019,13 +1019,14 @@ def dtd(root, config):
     return num_samples_map[config["split"], config["fold"]]


-# @register_mock
-def fer2013(info, root, config):
-    num_samples = 5 if config.split == "train" else 3
+@register_mock(configs=combinations_grid(split=("train", "test")))
+def fer2013(root, config):
+    split = config["split"]
+    num_samples = 5 if split == "train" else 3

-    path = root / f"{config.split}.csv"
+    path = root / f"{split}.csv"
     with open(path, "w", newline="") as file:
-        field_names = ["emotion"] if config.split == "train" else []
+        field_names = ["emotion"] if split == "train" else []
         field_names.append("pixels")

         file.write(",".join(field_names) + "\n")
@@ -1035,7 +1036,7 @@ def fer2013(info, root, config):
             rowdict = {
                 "pixels": " ".join([str(int(pixel)) for pixel in torch.randint(256, (48 * 48,), dtype=torch.uint8)])
             }
-            if config.split == "train":
+            if split == "train":
                 rowdict["emotion"] = int(torch.randint(7, ()))
             writer.writerow(rowdict)
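For reference, the configs argument presumably expands split=("train", "test") into one config dict per split (e.g. {"split": "train"} and {"split": "test"}), so the mock is generated once per split. Below is a self-contained sketch of the CSV layout the updated mock writes for the train split; the file name, field names, and row count follow the mock above, while the pixel values are random.

    # Standalone illustration of the mock CSV layout (not the test helper itself).
    import csv

    import torch

    with open("train.csv", "w", newline="") as file:
        field_names = ["emotion", "pixels"]  # the test split would omit "emotion"
        file.write(",".join(field_names) + "\n")

        writer = csv.DictWriter(file, fieldnames=field_names, quoting=csv.QUOTE_NONNUMERIC)
        for _ in range(5):  # the mock writes 5 rows for train (3 for test)
            writer.writerow(
                {
                    "emotion": int(torch.randint(7, ())),  # label id in [0, 7)
                    "pixels": " ".join(str(int(p)) for p in torch.randint(256, (48 * 48,), dtype=torch.uint8)),
                }
            )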
57 changes: 34 additions & 23 deletions torchvision/prototype/datasets/_builtin/fer2013.py
@@ -1,11 +1,10 @@
-from typing import Any, Dict, List, cast
+import pathlib
+from typing import Any, Dict, List, cast, Union

 import torch
 from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser
 from torchvision.prototype.datasets.utils import (
-    Dataset,
-    DatasetConfig,
-    DatasetInfo,
+    Dataset2,
     OnlineResource,
     KaggleDownloadResource,
 )
@@ -15,26 +14,40 @@
 )
 from torchvision.prototype.features import Label, Image

+from .._api import register_dataset, register_info
+
+NAME = "fer2013"
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict(categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"))
+
-
-class FER2013(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "fer2013",
-            homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge",
-            categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"),
-            valid_options=dict(split=("train", "test")),
-        )

+@register_dataset(NAME)
+class FER2013(Dataset2):
+    """FER 2013 Dataset
+    homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
+    """
+
+    def __init__(
+        self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", {"train", "test"})
+        self._categories = _info()["categories"]
+
+        super().__init__(root, skip_integrity_check=skip_integrity_check)
+
     _CHECKSUMS = {
         "train": "a2b7c9360cc0b38d21187e5eece01c2799fce5426cdeecf746889cc96cda2d10",
         "test": "dec8dfe8021e30cd6704b85ec813042b4a5d99d81cb55e023291a94104f575c3",
     }

-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+    def _resources(self) -> List[OnlineResource]:
         archive = KaggleDownloadResource(
-            cast(str, self.info.homepage),
-            file_name=f"{config.split}.csv.zip",
-            sha256=self._CHECKSUMS[config.split],
+            "https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge",
+            file_name=f"{self._split}.csv.zip",
+            sha256=self._CHECKSUMS[self._split],
         )
         return [archive]
@@ -43,17 +56,15 @@ def _prepare_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:

         return dict(
             image=Image(torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48)),
-            label=Label(int(label_id), categories=self.categories) if label_id is not None else None,
+            label=Label(int(label_id), categories=self._categories) if label_id is not None else None,
         )

-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = CSVDictParser(dp)
         dp = hint_shuffling(dp)
         dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
+
+    def __len__(self) -> int:
+        return 28_709 if self._split == "train" else 3_589

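Finally, a minimal sketch of the per-row decoding performed by _prepare_sample above, using a fabricated CSV row; real rows come from the Kaggle CSVs.

    # Sketch of how one CSV row becomes a sample (fabricated row values).
    import torch

    from torchvision.prototype.features import Image, Label

    categories = ("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral")

    row = {"emotion": "3", "pixels": " ".join(["128"] * 48 * 48)}  # 2304 space-separated pixel values

    image = Image(torch.tensor([int(idx) for idx in row["pixels"].split()], dtype=torch.uint8).reshape(48, 48))
    label_id = row.get("emotion")
    label = Label(int(label_id), categories=categories) if label_id is not None else None

    print(image.shape)  # 48x48 pixels (Image may normalize the layout, e.g. add a channel dimension)
    print(label)        # index 3 corresponds to "happy"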