diff --git a/torch_uncertainty/datamodules/tiny_imagenet.py b/torch_uncertainty/datamodules/tiny_imagenet.py index cfbf0f3b..beebe077 100644 --- a/torch_uncertainty/datamodules/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/tiny_imagenet.py @@ -1,4 +1,3 @@ -# fmt: off from argparse import ArgumentParser from pathlib import Path from typing import Any, List, Optional, Union @@ -7,10 +6,10 @@ from pytorch_lightning import LightningDataModule from timm.data.auto_augment import rand_augment_transform from torch import nn -from torch.utils.data import DataLoader, Dataset -from torchvision.datasets import SVHN +from torch.utils.data import ConcatDataset, DataLoader, Dataset +from torchvision.datasets import DTD, SVHN -from ..datasets.classification import TinyImageNet +from ..datasets.classification import ImageNetO, TinyImageNet # fmt: on @@ -24,6 +23,7 @@ def __init__( root: Union[str, Path], ood_detection: bool, batch_size: int, + ood_ds: str = "svhn", rand_augment_opt: Optional[str] = None, num_workers: int = 1, pin_memory: bool = True, @@ -41,8 +41,16 @@ def __init__( self.num_workers = num_workers self.pin_memory = pin_memory self.persistent_workers = persistent_workers + self.ood_ds = ood_ds + self.dataset = TinyImageNet - self.ood_dataset = SVHN + + if ood_ds == "imagenet-o": + self.ood_dataset = ImageNetO + elif ood_ds == "svhn": + self.ood_dataset = SVHN + elif ood_ds == "textures": + self.ood_dataset = DTD if rand_augment_opt is not None: main_transform = rand_augment_transform(rand_augment_opt, {}) @@ -61,7 +69,8 @@ def __init__( self.transform_test = T.Compose( [ - T.Resize(64), + T.Resize(72), + T.CenterCrop(64), T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ] @@ -76,7 +85,36 @@ def _verify_splits(self, split: str) -> None: def prepare_data(self) -> None: # coverage: ignore if self.ood_detection: - self.ood_dataset(self.root, split="test", download=True) + if self.ood_ds != "textures": + self.ood_dataset( + self.root, + split="test", + download=True, + transform=self.transform_test, + ) + else: + ConcatDataset( + [ + self.ood_dataset( + self.root, + split="train", + download=True, + transform=self.transform_test, + ), + self.ood_dataset( + self.root, + split="val", + download=True, + transform=self.transform_test, + ), + self.ood_dataset( + self.root, + split="test", + download=True, + transform=self.transform_test, + ), + ] + ) def setup(self, stage: Optional[str] = None) -> None: if stage == "fit" or stage is None: @@ -98,11 +136,35 @@ def setup(self, stage: Optional[str] = None) -> None: ) if self.ood_detection: - self.ood = self.ood_dataset( - self.root, - split="test", - transform=self.transform_test, - ) + if self.ood_ds == "textures": + self.ood = ConcatDataset( + [ + self.ood_dataset( + self.root, + split="train", + download=True, + transform=self.transform_test, + ), + self.ood_dataset( + self.root, + split="val", + download=True, + transform=self.transform_test, + ), + self.ood_dataset( + self.root, + split="test", + download=True, + transform=self.transform_test, + ), + ] + ) + else: + self.ood = self.ood_dataset( + self.root, + split="test", + transform=self.transform_test, + ) def train_dataloader(self) -> DataLoader: r"""Get the training dataloader for TinyImageNet. diff --git a/torch_uncertainty/datamodules/uci_regression.py b/torch_uncertainty/datamodules/uci_regression.py index 1b66727e..025d45ff 100644 --- a/torch_uncertainty/datamodules/uci_regression.py +++ b/torch_uncertainty/datamodules/uci_regression.py @@ -5,6 +5,7 @@ from typing import Any, Optional, Tuple, Union from pytorch_lightning import LightningDataModule +from torch import Generator from torch.utils.data import DataLoader, Dataset, random_split from ..datasets.regression import UCIRegression @@ -43,6 +44,7 @@ def __init__( pin_memory: bool = True, persistent_workers: bool = True, input_shape: Optional[Tuple[int, ...]] = None, + split_seed: int = 42, **kwargs, ) -> None: super().__init__() @@ -57,8 +59,11 @@ def __init__( self.pin_memory = pin_memory self.persistent_workers = persistent_workers - self.dataset = partial(UCIRegression, dataset_name=dataset_name) + self.dataset = partial( + UCIRegression, dataset_name=dataset_name, seed=split_seed + ) self.input_shape = input_shape + self.gen = Generator().manual_seed(split_seed) def prepare_data(self) -> None: """Download the dataset.""" @@ -79,6 +84,7 @@ def setup(self, stage: Optional[str] = None) -> None: - int(len(full) * 0.2) - int(len(full) * (0.8 - self.val_split)), ], + generator=self.gen, ) if self.val_split == 0: self.val = self.test