Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def write_version_file():
pytorch_dep += "==" + os.getenv("PYTORCH_VERSION")

requirements = [
"typing_extensions",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"numpy",
"requests",
pytorch_dep,
Expand Down
4 changes: 3 additions & 1 deletion torchvision/datasets/stl10.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os.path
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Optional, Tuple, cast

import numpy as np
from PIL import Image
Expand Down Expand Up @@ -65,10 +65,12 @@ def __init__(
self.labels: Optional[np.ndarray]
if self.split == "train":
self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
self.labels = cast(np.ndarray, self.labels)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have self.labels: Optional[np.ndarray] above, should we remove one of these?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I missed that. Let me have a look.

self.__load_folds(folds)

elif self.split == "train+unlabeled":
self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
self.labels = cast(np.ndarray, self.labels)
self.__load_folds(folds)
unlabeled_data, _ = self.__loadfile(self.train_list[2][0])
self.data = np.concatenate((self.data, unlabeled_data))
Expand Down
5 changes: 3 additions & 2 deletions torchvision/transforms/functional_pil.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import torch
from PIL import Image, ImageOps, ImageEnhance
from typing_extensions import Literal

try:
import accimage
Expand Down Expand Up @@ -130,7 +131,7 @@ def pad(
img: Image.Image,
padding: Union[int, List[int], Tuple[int, ...]],
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
padding_mode: str = "constant",
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a new one for me, Literally.

) -> Image.Image:

if not _is_pil_image(img):
Expand Down Expand Up @@ -189,7 +190,7 @@ def pad(
if img.mode == "P":
palette = img.getpalette()
img = np.asarray(img)
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode)
img = Image.fromarray(img)
img.putpalette(palette)
return img
Expand Down