diff --git a/autoPyTorch/datasets/base_dataset.py b/autoPyTorch/datasets/base_dataset.py index a9a5448..f03801d 100644 --- a/autoPyTorch/datasets/base_dataset.py +++ b/autoPyTorch/datasets/base_dataset.py @@ -1,7 +1,7 @@ from abc import ABCMeta from torch.utils.data import Dataset, Subset import numpy as np -from typing import Optional, Tuple, List, Any, Dict +from typing import Optional, Tuple, List, Any, Dict, Union from autoPyTorch.datasets.cross_validation import CROSS_VAL_FN, HOLDOUT_FN, is_stratified @@ -20,7 +20,9 @@ def type_check(train_tensors: Tuple[Any, ...], val_tensors: Optional[Tuple[Any, class BaseDataset(Dataset, metaclass=ABCMeta): - def __init__(self, train_tensors: Tuple[Any, ...], val_tensors: Optional[Tuple[Any, ...]] = None, + def __init__(self, + train_tensors: Union[Tuple[Any, ...]], + val_tensors: Optional[Tuple[Any, ...]] = None, shuffle: Optional[bool] = True, seed: Optional[int] = 42): """ :param train_tensors: A tuple of objects that have a __len__ and a __getitem__ attribute. 
@@ -39,7 +41,7 @@ def __getitem__(self, index: int) -> Tuple[np.ndarray, ...]: return tuple(t[index] for t in self.train_tensors) def __len__(self) -> int: - return self.train_tensors[0].shape[0] + return len(self.train_tensors[0]) def _get_indices(self) -> np.ndarray: if self.shuffle: diff --git a/autoPyTorch/datasets/image_dataset.py b/autoPyTorch/datasets/image_dataset.py index e07c82c..a453372 100644 --- a/autoPyTorch/datasets/image_dataset.py +++ b/autoPyTorch/datasets/image_dataset.py @@ -1,5 +1,7 @@ import numpy as np -from torch.utils.data import Dataset +from torch.utils.data import Dataset, TensorDataset +import torch +from PIL import Image from autoPyTorch.datasets.base_dataset import BaseDataset from typing import Tuple, Optional, Union, List from autoPyTorch.datasets.cross_validation import k_fold_cross_validation, \ @@ -11,9 +13,11 @@ class ImageDataset(BaseDataset): def __init__(self, train: Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]], val: Optional[Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]]] = None): - _type_check(train, "train") _check_image_inputs(train=train, val=val) - super().__init__(train_tensors=train, val_tensors=val, shuffle=True) + train = _create_image_dataset(data=train) + if val is not None: + val = _create_image_dataset(data=val) + super().__init__(train_tensors=(train,), val_tensors=(val,) if val is not None else None, shuffle=True) self.cross_validators.update( {"k_fold_cross_validation": k_fold_cross_validation} ) @@ -23,21 +27,38 @@ def __init__(self, ) -def _type_check(t: Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]], name: str): - # TODO: finish this - if isinstance(t, Dataset): - raise ValueError("") - elif isinstance(t, tuple): - if len(t) != 2: - raise ValueError("") - pass +def _check_image_inputs(train: Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]], + val: Optional[Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]]]): + if not isinstance(train, Dataset): + if 
len(train[0]) != len(train[1]): + raise ValueError( + f"expected train inputs to have the same length, but got lengths {len(train[0])} and {len(train[1])}") + if val is not None: + if len(val[0]) != len(val[1]): + raise ValueError( + f"expected val inputs to have the same length, but got lengths {len(val[0])} and {len(val[1])}") + + +def _create_image_dataset(data: Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]]) -> Dataset: + # if user already provided a dataset, use it + if isinstance(data, Dataset): + return data + # if user provided list of file paths, create a file path dataset + if isinstance(data[0], list): + return _FilePathDataset(file_paths=data[0], targets=data[1]) + # if user provided the images as numpy tensors use them directly else: - raise TypeError(f"expected input `{name}` to be of type torch.utils.data.Dataset, {type(Dataset)}") + return TensorDataset(torch.tensor(data[0]), torch.tensor(data[1])) -def _check_image_inputs(train: Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]], - val: Optional[Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]]]): - # TODO: finish this - _type_check(train, "train") - if val is not None: - _type_check(val, "val") +class _FilePathDataset(Dataset): + def __init__(self, file_paths: List[str], targets: np.ndarray): + self.file_paths = file_paths + self.targets = targets + + def __getitem__(self, index: int): + img = Image.open(self.file_paths[index]).convert("RGB") + return img, torch.tensor(self.targets[index]) + + def __len__(self) -> int: + return len(self.file_paths)