Skip to content

Commit

Permalink
✨ Refine TinyImageNet and UCIRegression DMs
Browse files Browse the repository at this point in the history
  • Loading branch information
o-laurent committed Oct 11, 2023
1 parent ea2e56b commit cffea9c
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 13 deletions.
86 changes: 74 additions & 12 deletions torch_uncertainty/datamodules/tiny_imagenet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# fmt: off
from argparse import ArgumentParser
from pathlib import Path
from typing import Any, List, Optional, Union
Expand All @@ -7,10 +6,10 @@
from pytorch_lightning import LightningDataModule
from timm.data.auto_augment import rand_augment_transform
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import SVHN
from torch.utils.data import ConcatDataset, DataLoader, Dataset
from torchvision.datasets import DTD, SVHN

from ..datasets.classification import TinyImageNet
from ..datasets.classification import ImageNetO, TinyImageNet


# fmt: on
Expand All @@ -24,6 +23,7 @@ def __init__(
root: Union[str, Path],
ood_detection: bool,
batch_size: int,
ood_ds: str = "svhn",
rand_augment_opt: Optional[str] = None,
num_workers: int = 1,
pin_memory: bool = True,
Expand All @@ -41,8 +41,16 @@ def __init__(
self.num_workers = num_workers
self.pin_memory = pin_memory
self.persistent_workers = persistent_workers
self.ood_ds = ood_ds

self.dataset = TinyImageNet
self.ood_dataset = SVHN

if ood_ds == "imagenet-o":
self.ood_dataset = ImageNetO
elif ood_ds == "svhn":
self.ood_dataset = SVHN
elif ood_ds == "textures":
self.ood_dataset = DTD

if rand_augment_opt is not None:
main_transform = rand_augment_transform(rand_augment_opt, {})
Expand All @@ -61,7 +69,8 @@ def __init__(

self.transform_test = T.Compose(
[
T.Resize(64),
T.Resize(72),

This comment has been minimized.

Copy link
@o-laurent

o-laurent Oct 26, 2023

Author Contributor

Weirdly, resizing to 72 before the 64-pixel center crop (instead of resizing directly to 64) has a HUGE impact on performance.

T.CenterCrop(64),
T.ToTensor(),
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
]
Expand All @@ -76,7 +85,36 @@ def _verify_splits(self, split: str) -> None:

def prepare_data(self) -> None:  # coverage: ignore
    """Download the out-of-distribution dataset if OOD detection is on.

    Nothing is downloaded when ``self.ood_detection`` is false. Otherwise
    ``self.ood_dataset`` is instantiated once per required split with
    ``download=True``. The instances are discarded: they are created only
    for the download side effect, so no ``ConcatDataset`` is built here.
    """
    if not self.ood_detection:
        return

    # "textures" needs its train, val and test splits (the same three
    # splits that ``setup`` concatenates); every other OOD dataset only
    # needs its test split.
    if self.ood_ds == "textures":
        splits = ("train", "val", "test")
    else:
        splits = ("test",)

    for split in splits:
        self.ood_dataset(
            self.root,
            split=split,
            download=True,
            transform=self.transform_test,
        )

def setup(self, stage: Optional[str] = None) -> None:
if stage == "fit" or stage is None:
Expand All @@ -98,11 +136,35 @@ def setup(self, stage: Optional[str] = None) -> None:
)

if self.ood_detection:
self.ood = self.ood_dataset(
self.root,
split="test",
transform=self.transform_test,
)
if self.ood_ds == "textures":
self.ood = ConcatDataset(
[
self.ood_dataset(
self.root,
split="train",
download=True,
transform=self.transform_test,
),
self.ood_dataset(
self.root,
split="val",
download=True,
transform=self.transform_test,
),
self.ood_dataset(
self.root,
split="test",
download=True,
transform=self.transform_test,
),
]
)
else:
self.ood = self.ood_dataset(
self.root,
split="test",
transform=self.transform_test,
)

def train_dataloader(self) -> DataLoader:
r"""Get the training dataloader for TinyImageNet.
Expand Down
8 changes: 7 additions & 1 deletion torch_uncertainty/datamodules/uci_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Optional, Tuple, Union

from pytorch_lightning import LightningDataModule
from torch import Generator
from torch.utils.data import DataLoader, Dataset, random_split

from ..datasets.regression import UCIRegression
Expand Down Expand Up @@ -43,6 +44,7 @@ def __init__(
pin_memory: bool = True,
persistent_workers: bool = True,
input_shape: Optional[Tuple[int, ...]] = None,
split_seed: int = 42,
**kwargs,
) -> None:
super().__init__()
Expand All @@ -57,8 +59,11 @@ def __init__(
self.pin_memory = pin_memory
self.persistent_workers = persistent_workers

self.dataset = partial(UCIRegression, dataset_name=dataset_name)
self.dataset = partial(
UCIRegression, dataset_name=dataset_name, seed=split_seed
)
self.input_shape = input_shape
self.gen = Generator().manual_seed(split_seed)

def prepare_data(self) -> None:
"""Download the dataset."""
Expand All @@ -79,6 +84,7 @@ def setup(self, stage: Optional[str] = None) -> None:
- int(len(full) * 0.2)
- int(len(full) * (0.8 - self.val_split)),
],
generator=self.gen,
)
if self.val_split == 0:
self.val = self.test
Expand Down

0 comments on commit cffea9c

Please sign in to comment.