Skip to content

Commit

Permalink
image dataset update
Browse files Browse the repository at this point in the history
  • Loading branch information
bastiscode committed Sep 16, 2020
1 parent 05026e5 commit ea48f80
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 21 deletions.
8 changes: 5 additions & 3 deletions autoPyTorch/datasets/base_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABCMeta
from torch.utils.data import Dataset, Subset
import numpy as np
from typing import Optional, Tuple, List, Any, Dict
from typing import Optional, Tuple, List, Any, Dict, Union
from autoPyTorch.datasets.cross_validation import CROSS_VAL_FN, HOLDOUT_FN, is_stratified


Expand All @@ -20,7 +20,9 @@ def type_check(train_tensors: Tuple[Any, ...], val_tensors: Optional[Tuple[Any,


class BaseDataset(Dataset, metaclass=ABCMeta):
def __init__(self, train_tensors: Tuple[Any, ...], val_tensors: Optional[Tuple[Any, ...]] = None,
def __init__(self,
train_tensors: Union[Tuple[Any, ...]],
val_tensors: Optional[Tuple[Any, ...]] = None,
shuffle: Optional[bool] = True, seed: Optional[int] = 42):
"""
:param train_tensors: A tuple of objects that have a __len__ and a __getitem__ attribute.
Expand All @@ -39,7 +41,7 @@ def __getitem__(self, index: int) -> Tuple[np.ndarray, ...]:
return tuple(t[index] for t in self.train_tensors)

def __len__(self) -> int:
return self.train_tensors[0].shape[0]
return len(self.train_tensors[0])

def _get_indices(self) -> np.ndarray:
if self.shuffle:
Expand Down
57 changes: 39 additions & 18 deletions autoPyTorch/datasets/image_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import Dataset, TensorDataset
import torch
from PIL import Image
from autoPyTorch.datasets.base_dataset import BaseDataset
from typing import Tuple, Optional, Union, List
from autoPyTorch.datasets.cross_validation import k_fold_cross_validation, \
Expand All @@ -11,9 +13,11 @@ class ImageDataset(BaseDataset):
def __init__(self,
train: Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]],
val: Optional[Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]]] = None):
_type_check(train, "train")
_check_image_inputs(train=train, val=val)
super().__init__(train_tensors=train, val_tensors=val, shuffle=True)
train = _create_image_dataset(data=train)
if val is not None:
val = _create_image_dataset(data=val)
super().__init__(train_tensors=(train,), val_tensors=(val,), shuffle=True)
self.cross_validators.update(
{"k_fold_cross_validation": k_fold_cross_validation}
)
Expand All @@ -23,21 +27,38 @@ def __init__(self,
)


def _type_check(t: Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]], name: str):
# TODO: finish this
if isinstance(t, Dataset):
raise ValueError("")
elif isinstance(t, tuple):
if len(t) != 2:
raise ValueError("")
pass
def _check_image_inputs(train: Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]],
                        val: Optional[Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]]]):
    """
    Validate that tuple-style image inputs pair every input with exactly one target.

    Inputs that are already a ``torch.utils.data.Dataset`` are not inspected, because
    indexing a dataset yields samples rather than the (inputs, targets) pair.

    :param train: Either a dataset, or a tuple ``(inputs, targets)`` where ``inputs`` is a numpy array
        or a list of file paths and ``targets`` is a numpy array.
    :param val: Optional validation data in the same format as ``train``.
    :raises ValueError: If the inputs and targets of a tuple-style split differ in length.
    """
    splits = [("train", train)]
    if val is not None:
        splits.append(("val", val))

    for split_name, data in splits:
        # a user-provided dataset is taken as-is; data[0]/data[1] would be samples, not inputs/targets
        if isinstance(data, Dataset):
            continue
        num_inputs, num_targets = len(data[0]), len(data[1])
        if num_inputs != num_targets:
            # report the lengths of the split that actually failed (not always the train lengths)
            raise ValueError(
                f"expected {split_name} inputs to have the same length, "
                f"but got lengths {num_inputs} and {num_targets}")


def _create_image_dataset(data: Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]]) -> Dataset:
    """
    Turn the user-provided image data into a ``torch.utils.data.Dataset``.

    :param data: Either a dataset (returned unchanged), or a tuple ``(inputs, targets)`` where ``inputs``
        is a list of image file paths or a numpy array of images and ``targets`` is a numpy array.
    :return: The given dataset, a ``_FilePathDataset`` for file paths, or a ``TensorDataset`` for arrays.
    """
    # if user already provided a dataset, use it
    if isinstance(data, Dataset):
        return data

    inputs, targets = data[0], data[1]

    # if user provided list of file paths, create a file path dataset
    if isinstance(inputs, list):
        return _FilePathDataset(file_paths=inputs, targets=targets)

    # if user provided the images as numpy tensors use them directly
    return TensorDataset(torch.tensor(inputs), torch.tensor(targets))


def _check_image_inputs(train: Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]],
val: Optional[Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]]]):
# TODO: finish this
_type_check(train, "train")
if val is not None:
_type_check(val, "val")
class _FilePathDataset(Dataset):
    """
    Dataset that reads images from disk on access instead of holding them in memory.

    :param file_paths: Paths of the image files, one per sample.
    :param targets: Targets of the samples, indexed in the same order as ``file_paths``.
    """

    def __init__(self, file_paths: List[str], targets: np.ndarray):
        self.file_paths = file_paths
        self.targets = targets

    def __getitem__(self, index: int):
        """
        Load the image at ``index`` and return it with its target.

        :param index: Position of the sample.
        :return: Tuple of the image converted to RGB and the target wrapped in a tensor.
        """
        # convert() returns a new image, so the opened file can be released right away
        # instead of waiting for garbage collection (Pillow may keep it open otherwise,
        # e.g. for multi-frame formats)
        with Image.open(self.file_paths[index]) as source:
            img = source.convert("RGB")
        return img, torch.tensor(self.targets[index])

    def __len__(self) -> int:
        """Return the number of samples, i.e. the number of file paths."""
        return len(self.file_paths)

0 comments on commit ea48f80

Please sign in to comment.