Skip to content

Commit 01977a1

Browse files
committed
fix root type annotations for remaining datasets
1 parent a728ed7 commit 01977a1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+127
-99
lines changed

torchvision/datasets/_optical_flow.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from .utils import _read_pfm, verify_str_arg
1414
from .vision import VisionDataset
1515

16-
1716
T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
1817
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]
1918

@@ -33,7 +32,7 @@ class FlowDataset(ABC, VisionDataset):
3332
# and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
3433
_has_builtin_flow_mask = False
3534

36-
def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
35+
def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
3736

3837
super().__init__(root=root)
3938
self.transforms = transforms
@@ -125,7 +124,7 @@ class Sintel(FlowDataset):
125124

126125
def __init__(
127126
self,
128-
root: str,
127+
root: Union[str, Path],
129128
split: str = "train",
130129
pass_name: str = "clean",
131130
transforms: Optional[Callable] = None,
@@ -191,7 +190,7 @@ class KittiFlow(FlowDataset):
191190

192191
_has_builtin_flow_mask = True
193192

194-
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
193+
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
195194
super().__init__(root=root, transforms=transforms)
196195

197196
verify_str_arg(split, "split", valid_values=("train", "test"))
@@ -256,7 +255,7 @@ class FlyingChairs(FlowDataset):
256255
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
257256
"""
258257

259-
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
258+
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
260259
super().__init__(root=root, transforms=transforms)
261260

262261
verify_str_arg(split, "split", valid_values=("train", "val"))
@@ -329,7 +328,7 @@ class FlyingThings3D(FlowDataset):
329328

330329
def __init__(
331330
self,
332-
root: str,
331+
root: Union[str, Path],
333332
split: str = "train",
334333
pass_name: str = "clean",
335334
camera: str = "left",
@@ -419,7 +418,7 @@ class HD1K(FlowDataset):
419418

420419
_has_builtin_flow_mask = True
421420

422-
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
421+
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
423422
super().__init__(root=root, transforms=transforms)
424423

425424
verify_str_arg(split, "split", valid_values=("train", "test"))

torchvision/datasets/_stereo_matching.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class StereoMatchingDataset(ABC, VisionDataset):
2727

2828
_has_built_in_disparity_mask = False
2929

30-
def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
30+
def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
3131
"""
3232
Args:
3333
root(str): Root directory of the dataset.
@@ -163,7 +163,7 @@ class CarlaStereo(StereoMatchingDataset):
163163
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
164164
"""
165165

166-
def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
166+
def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
167167
super().__init__(root, transforms)
168168

169169
root = Path(root) / "carla-highres"
@@ -240,7 +240,7 @@ class Kitti2012Stereo(StereoMatchingDataset):
240240

241241
_has_built_in_disparity_mask = True
242242

243-
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
243+
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
244244
super().__init__(root, transforms)
245245

246246
verify_str_arg(split, "split", valid_values=("train", "test"))
@@ -328,7 +328,7 @@ class Kitti2015Stereo(StereoMatchingDataset):
328328

329329
_has_built_in_disparity_mask = True
330330

331-
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
331+
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
332332
super().__init__(root, transforms)
333333

334334
verify_str_arg(split, "split", valid_values=("train", "test"))
@@ -480,7 +480,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
480480

481481
def __init__(
482482
self,
483-
root: str,
483+
root: Union[str, Path],
484484
split: str = "train",
485485
calibration: Optional[str] = "perfect",
486486
use_ambient_views: bool = False,
@@ -576,7 +576,7 @@ def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.n
576576
valid_mask = (disparity_map > 0).squeeze(0) # mask out invalid disparities
577577
return disparity_map, valid_mask
578578

579-
def _download_dataset(self, root: str) -> None:
579+
def _download_dataset(self, root: Union[str, Path]) -> None:
580580
base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
581581
# train and additional splits have 2 different calibration settings
582582
root = Path(root) / "Middlebury2014"
@@ -675,7 +675,7 @@ class CREStereo(StereoMatchingDataset):
675675

676676
def __init__(
677677
self,
678-
root: str,
678+
root: Union[str, Path],
679679
transforms: Optional[Callable] = None,
680680
) -> None:
681681
super().__init__(root, transforms)
@@ -762,7 +762,7 @@ class FallingThingsStereo(StereoMatchingDataset):
762762
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
763763
"""
764764

765-
def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None) -> None:
765+
def __init__(self, root: Union[str, Path], variant: str = "single", transforms: Optional[Callable] = None) -> None:
766766
super().__init__(root, transforms)
767767

768768
root = Path(root) / "FallingThings"
@@ -877,7 +877,7 @@ class SceneFlowStereo(StereoMatchingDataset):
877877

878878
def __init__(
879879
self,
880-
root: str,
880+
root: Union[str, Path],
881881
variant: str = "FlyingThings3D",
882882
pass_name: str = "clean",
883883
transforms: Optional[Callable] = None,
@@ -980,7 +980,7 @@ class SintelStereo(StereoMatchingDataset):
980980

981981
_has_built_in_disparity_mask = True
982982

983-
def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
983+
def __init__(self, root: Union[str, Path], pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
984984
super().__init__(root, transforms)
985985

986986
verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))
@@ -1087,7 +1087,7 @@ class InStereo2k(StereoMatchingDataset):
10871087
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
10881088
"""
10891089

1090-
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
1090+
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
10911091
super().__init__(root, transforms)
10921092

10931093
root = Path(root) / "InStereo2k" / split
@@ -1176,7 +1176,7 @@ class ETH3DStereo(StereoMatchingDataset):
11761176

11771177
_has_built_in_disparity_mask = True
11781178

1179-
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
1179+
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
11801180
super().__init__(root, transforms)
11811181

11821182
verify_str_arg(split, "split", valid_values=("train", "test"))

torchvision/datasets/caltech.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import os.path
3+
from pathlib import Path
34
from typing import Any, Callable, List, Optional, Tuple, Union
45

56
from PIL import Image
@@ -38,7 +39,7 @@ class Caltech101(VisionDataset):
3839

3940
def __init__(
4041
self,
41-
root: str,
42+
root: Union[str, Path],
4243
target_type: Union[List[str], str] = "category",
4344
transform: Optional[Callable] = None,
4445
target_transform: Optional[Callable] = None,

torchvision/datasets/celeba.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import csv
22
import os
33
from collections import namedtuple
4+
from pathlib import Path
45
from typing import Any, Callable, List, Optional, Tuple, Union
56

67
import PIL
@@ -63,7 +64,7 @@ class CelebA(VisionDataset):
6364

6465
def __init__(
6566
self,
66-
root: str,
67+
root: Union[str, Path],
6768
split: str = "train",
6869
target_type: Union[List[str], str] = "attr",
6970
transform: Optional[Callable] = None,

torchvision/datasets/cifar.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os.path
22
import pickle
3-
from typing import Any, Callable, Optional, Tuple
3+
from pathlib import Path
4+
from typing import Any, Callable, Optional, Tuple, Union
45

56
import numpy as np
67
from PIL import Image
@@ -50,7 +51,7 @@ class CIFAR10(VisionDataset):
5051

5152
def __init__(
5253
self,
53-
root: str,
54+
root: Union[str, Path],
5455
train: bool = True,
5556
transform: Optional[Callable] = None,
5657
target_transform: Optional[Callable] = None,

torchvision/datasets/cityscapes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import os
33
from collections import namedtuple
4+
from pathlib import Path
45
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
56

67
from PIL import Image
@@ -103,7 +104,7 @@ class Cityscapes(VisionDataset):
103104

104105
def __init__(
105106
self,
106-
root: str,
107+
root: Union[str, Path],
107108
split: str = "train",
108109
mode: str = "fine",
109110
target_type: Union[List[str], str] = "instance",

torchvision/datasets/clevr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import pathlib
3-
from typing import Any, Callable, List, Optional, Tuple
3+
from typing import Any, Callable, List, Optional, Tuple, Union
44
from urllib.parse import urlparse
55

66
from PIL import Image
@@ -30,7 +30,7 @@ class CLEVRClassification(VisionDataset):
3030

3131
def __init__(
3232
self,
33-
root: str,
33+
root: Union[str, pathlib.Path],
3434
split: str = "train",
3535
transform: Optional[Callable] = None,
3636
target_transform: Optional[Callable] = None,

torchvision/datasets/coco.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os.path
2-
from typing import Any, Callable, List, Optional, Tuple
2+
from pathlib import Path
3+
from typing import Any, Callable, List, Optional, Tuple, Union
34

45
from PIL import Image
56

@@ -24,7 +25,7 @@ class CocoDetection(VisionDataset):
2425

2526
def __init__(
2627
self,
27-
root: str,
28+
root: Union[str, Path],
2829
annFile: str,
2930
transform: Optional[Callable] = None,
3031
target_transform: Optional[Callable] = None,

torchvision/datasets/country211.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Callable, Optional
2+
from typing import Callable, Optional, Union
33

44
from .folder import ImageFolder
55
from .utils import download_and_extract_archive, verify_str_arg
@@ -28,7 +28,7 @@ class Country211(ImageFolder):
2828

2929
def __init__(
3030
self,
31-
root: str,
31+
root: Union[str, Path],
3232
split: str = "train",
3333
transform: Optional[Callable] = None,
3434
target_transform: Optional[Callable] = None,

torchvision/datasets/dtd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
import pathlib
3-
from typing import Any, Callable, Optional, Tuple
3+
from typing import Any, Callable, Optional, Tuple, Union
44

55
import PIL.Image
66

@@ -34,7 +34,7 @@ class DTD(VisionDataset):
3434

3535
def __init__(
3636
self,
37-
root: str,
37+
root: Union[str, pathlib.Path],
3838
split: str = "train",
3939
partition: int = 1,
4040
transform: Optional[Callable] = None,

0 commit comments

Comments
 (0)