Skip to content

Commit

Permalink
Document that datasets support pathlib.Path (#8321)
Browse files Browse the repository at this point in the history
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
  • Loading branch information
NicolasHug and pmeier authored Mar 18, 2024
1 parent 0325175 commit 2ba586d
Show file tree
Hide file tree
Showing 46 changed files with 216 additions and 181 deletions.
23 changes: 11 additions & 12 deletions torchvision/datasets/_optical_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from .utils import _read_pfm, verify_str_arg
from .vision import VisionDataset


T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]

Expand All @@ -33,7 +32,7 @@ class FlowDataset(ABC, VisionDataset):
# and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
_has_builtin_flow_mask = False

def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:

super().__init__(root=root)
self.transforms = transforms
Expand Down Expand Up @@ -113,7 +112,7 @@ class Sintel(FlowDataset):
...
Args:
root (string): Root directory of the Sintel Dataset.
root (str or ``pathlib.Path``): Root directory of the Sintel Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
details on the different passes.
Expand All @@ -125,7 +124,7 @@ class Sintel(FlowDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
pass_name: str = "clean",
transforms: Optional[Callable] = None,
Expand Down Expand Up @@ -183,15 +182,15 @@ class KittiFlow(FlowDataset):
flow_occ
Args:
root (string): Root directory of the KittiFlow Dataset.
root (str or ``pathlib.Path``): Root directory of the KittiFlow Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
"""

_has_builtin_flow_mask = True

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
Expand Down Expand Up @@ -248,15 +247,15 @@ class FlyingChairs(FlowDataset):
Args:
root (string): Root directory of the FlyingChairs Dataset.
root (str or ``pathlib.Path``): Root directory of the FlyingChairs Dataset.
split (string, optional): The dataset split, either "train" (default) or "val"
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
``valid_flow_mask`` is expected for consistency with other datasets which
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms)

verify_str_arg(split, "split", valid_values=("train", "val"))
Expand Down Expand Up @@ -316,7 +315,7 @@ class FlyingThings3D(FlowDataset):
TRAIN
Args:
root (string): Root directory of the intel FlyingThings3D Dataset.
root (str or ``pathlib.Path``): Root directory of the intel FlyingThings3D Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
details on the different passes.
Expand All @@ -329,7 +328,7 @@ class FlyingThings3D(FlowDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
pass_name: str = "clean",
camera: str = "left",
Expand Down Expand Up @@ -411,15 +410,15 @@ class HD1K(FlowDataset):
image_2
Args:
root (string): Root directory of the HD1K Dataset.
root (str or ``pathlib.Path``): Root directory of the HD1K Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
"""

_has_builtin_flow_mask = True

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
Expand Down
42 changes: 21 additions & 21 deletions torchvision/datasets/_stereo_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class StereoMatchingDataset(ABC, VisionDataset):

_has_built_in_disparity_mask = False

def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
"""
Args:
root(str): Root directory of the dataset.
Expand Down Expand Up @@ -159,11 +159,11 @@ class CarlaStereo(StereoMatchingDataset):
...
Args:
root (string): Root directory where `carla-highres` is located.
root (str or ``pathlib.Path``): Root directory where `carla-highres` is located.
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

root = Path(root) / "carla-highres"
Expand Down Expand Up @@ -233,14 +233,14 @@ class Kitti2012Stereo(StereoMatchingDataset):
calib
Args:
root (string): Root directory where `Kitti2012` is located.
root (str or ``pathlib.Path``): Root directory where `Kitti2012` is located.
split (string, optional): The dataset split of scenes, either "train" (default) or "test".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

_has_built_in_disparity_mask = True

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
Expand Down Expand Up @@ -321,14 +321,14 @@ class Kitti2015Stereo(StereoMatchingDataset):
calib
Args:
root (string): Root directory where `Kitti2015` is located.
root (str or ``pathlib.Path``): Root directory where `Kitti2015` is located.
split (string, optional): The dataset split of scenes, either "train" (default) or "test".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

_has_built_in_disparity_mask = True

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
Expand Down Expand Up @@ -420,7 +420,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
...
Args:
root (string): Root directory of the Middleburry 2014 Dataset.
root (str or ``pathlib.Path``): Root directory of the Middleburry 2014 Dataset.
split (string, optional): The dataset split of scenes, either "train" (default), "test", or "additional"
use_ambient_views (boolean, optional): Whether to use different expose or lightning views when possible.
The dataset samples with equal probability between ``[im1.png, im1E.png, im1L.png]``.
Expand Down Expand Up @@ -480,7 +480,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
calibration: Optional[str] = "perfect",
use_ambient_views: bool = False,
Expand Down Expand Up @@ -576,7 +576,7 @@ def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.n
valid_mask = (disparity_map > 0).squeeze(0) # mask out invalid disparities
return disparity_map, valid_mask

def _download_dataset(self, root: str) -> None:
def _download_dataset(self, root: Union[str, Path]) -> None:
base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
# train and additional splits have 2 different calibration settings
root = Path(root) / "Middlebury2014"
Expand Down Expand Up @@ -675,7 +675,7 @@ class CREStereo(StereoMatchingDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
transforms: Optional[Callable] = None,
) -> None:
super().__init__(root, transforms)
Expand Down Expand Up @@ -757,12 +757,12 @@ class FallingThingsStereo(StereoMatchingDataset):
...
Args:
root (string): Root directory where FallingThings is located.
root (str or ``pathlib.Path``): Root directory where FallingThings is located.
variant (string): Which variant to use. Either "single", "mixed", or "both".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], variant: str = "single", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

root = Path(root) / "FallingThings"
Expand Down Expand Up @@ -868,7 +868,7 @@ class SceneFlowStereo(StereoMatchingDataset):
...
Args:
root (string): Root directory where SceneFlow is located.
root (str or ``pathlib.Path``): Root directory where SceneFlow is located.
variant (string): Which dataset variant to user, "FlyingThings3D" (default), "Monkaa" or "Driving".
pass_name (string): Which pass to use, "clean" (default), "final" or "both".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
Expand All @@ -877,7 +877,7 @@ class SceneFlowStereo(StereoMatchingDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
variant: str = "FlyingThings3D",
pass_name: str = "clean",
transforms: Optional[Callable] = None,
Expand Down Expand Up @@ -973,14 +973,14 @@ class SintelStereo(StereoMatchingDataset):
...
Args:
root (string): Root directory where Sintel Stereo is located.
root (str or ``pathlib.Path``): Root directory where Sintel Stereo is located.
pass_name (string): The name of the pass to use, either "final", "clean" or "both".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

_has_built_in_disparity_mask = True

def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))
Expand Down Expand Up @@ -1082,12 +1082,12 @@ class InStereo2k(StereoMatchingDataset):
...
Args:
root (string): Root directory where InStereo2k is located.
root (str or ``pathlib.Path``): Root directory where InStereo2k is located.
split (string): Either "train" or "test".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

root = Path(root) / "InStereo2k" / split
Expand Down Expand Up @@ -1169,14 +1169,14 @@ class ETH3DStereo(StereoMatchingDataset):
...
Args:
root (string): Root directory of the ETH3D Dataset.
root (str or ``pathlib.Path``): Root directory of the ETH3D Dataset.
split (string, optional): The dataset split of scenes, either "train" (default) or "test".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

_has_built_in_disparity_mask = True

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
Expand Down
7 changes: 4 additions & 3 deletions torchvision/datasets/caltech.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import os.path
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

from PIL import Image
Expand All @@ -16,7 +17,7 @@ class Caltech101(VisionDataset):
This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
Args:
root (string): Root directory of dataset where directory
root (str or ``pathlib.Path``): Root directory of dataset where directory
``caltech101`` exists or will be saved to if download is set to True.
target_type (string or list, optional): Type of target to use, ``category`` or
``annotation``. Can also be a list to output a tuple with all specified
Expand All @@ -38,7 +39,7 @@ class Caltech101(VisionDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
target_type: Union[List[str], str] = "category",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
Expand Down Expand Up @@ -153,7 +154,7 @@ class Caltech256(VisionDataset):
"""`Caltech 256 <https://data.caltech.edu/records/20087>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
root (str or ``pathlib.Path``): Root directory of dataset where directory
``caltech256`` exists or will be saved to if download is set to True.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
Expand Down
5 changes: 3 additions & 2 deletions torchvision/datasets/celeba.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import csv
import os
from collections import namedtuple
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

import PIL
Expand All @@ -16,7 +17,7 @@ class CelebA(VisionDataset):
"""`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
Args:
root (string): Root directory where images are downloaded to.
root (str or ``pathlib.Path``): Root directory where images are downloaded to.
split (string): One of {'train', 'valid', 'test', 'all'}.
Accordingly dataset is selected.
target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
Expand Down Expand Up @@ -63,7 +64,7 @@ class CelebA(VisionDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
target_type: Union[List[str], str] = "attr",
transform: Optional[Callable] = None,
Expand Down
7 changes: 4 additions & 3 deletions torchvision/datasets/cifar.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os.path
import pickle
from typing import Any, Callable, Optional, Tuple
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union

import numpy as np
from PIL import Image
Expand All @@ -13,7 +14,7 @@ class CIFAR10(VisionDataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
root (str or ``pathlib.Path``): Root directory of dataset where directory
``cifar-10-batches-py`` exists or will be saved to if download is set to True.
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
Expand Down Expand Up @@ -50,7 +51,7 @@ class CIFAR10(VisionDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
Expand Down
5 changes: 3 additions & 2 deletions torchvision/datasets/cityscapes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
from collections import namedtuple
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from PIL import Image
Expand All @@ -13,7 +14,7 @@ class Cityscapes(VisionDataset):
"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
Args:
root (string): Root directory of dataset where directory ``leftImg8bit``
root (str or ``pathlib.Path``): Root directory of dataset where directory ``leftImg8bit``
and ``gtFine`` or ``gtCoarse`` are located.
split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
otherwise ``train``, ``train_extra`` or ``val``
Expand Down Expand Up @@ -103,7 +104,7 @@ class Cityscapes(VisionDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
mode: str = "fine",
target_type: Union[List[str], str] = "instance",
Expand Down
Loading

0 comments on commit 2ba586d

Please sign in to comment.