-
Notifications
You must be signed in to change notification settings - Fork 1
/
dataloader.py
92 lines (69 loc) · 2.75 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from typing import Optional, Callable
import os
from glob import glob
from PIL import Image
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import VisionDataset
from torchvision import transforms
from utils import fft2d, zero_padding_twice
__DATASETS__ = {}  # registry: dataset name -> registered dataset class


def register_dataset(name: str):
    """Return a class decorator that registers the decorated class under *name*.

    The decorated class is returned unchanged, so the decorator can be stacked
    directly on a class definition.

    Raises:
        NameError: if *name* is already registered.
    """
    def wrapper(cls):
        # Test key presence rather than the truthiness of the stored value, so a
        # duplicate is rejected whatever object was registered first.
        if name in __DATASETS__:
            raise NameError(f"Name {name} is already registered.")
        __DATASETS__[name] = cls
        return cls
    return wrapper


def get_dataset(name: str):
    """Return the class registered under *name*.

    Raises:
        NameError: if nothing is registered under *name*.
    """
    try:
        return __DATASETS__[name]
    except KeyError:
        # Hide the internal KeyError; callers only see the NameError contract.
        raise NameError(f"Name {name} is not defined.") from None
def get_valid_loader(dataset_name: str, root: str, batch_size: int, **kwargs):
    """Build a DataLoader over the validation split of a registered dataset.

    Args:
        dataset_name: name the dataset class was registered under.
        root: dataset root directory, forwarded to the dataset constructor.
        batch_size: batch size handed to ``DataLoader``.
        **kwargs: extra dataset constructor arguments (e.g. ``sigma`` for
            ``noise_amplitude_dataset``).

    The dataset is constructed with ``train=False`` and a ``ToTensor``
    transform. No shuffle or sampler options are passed to ``DataLoader``.
    """
    pipeline = transforms.Compose([transforms.ToTensor()])
    dataset_cls = get_dataset(dataset_name)
    valid_set = dataset_cls(root, False, transform=pipeline, **kwargs)
    return DataLoader(valid_set, batch_size)
@register_dataset(name='png_dataset')
class PNGDataset(VisionDataset):
    """Dataset of the ``*.png`` files under ``<root>/train`` or ``<root>/valid``.

    Each item is the image converted to RGB and, when ``transform`` is given,
    passed through it.
    """

    def __init__(self, root: str, train: bool, transform: Optional[Callable] = None):
        """
        Args:
            root: dataset root directory; images are read from its ``train``
                or ``valid`` subdirectory.
            train: read the ``train`` split when true, ``valid`` otherwise.
            transform: optional callable applied to each RGB PIL image.
        """
        super().__init__(root=root)
        self.transform = transform
        stage = "train" if train else "valid"
        # glob order is arbitrary; sort so item indices are deterministic.
        self.image_paths = sorted(glob(os.path.join(root, stage, '*.png')))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index: int):
        img_path = self.image_paths[index]
        # Close the file deterministically. convert() returns a new, fully
        # loaded image, so it stays valid after the handle is closed.
        with Image.open(fp=img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image
@register_dataset(name='amplitude_no_pad_dataset')
class AmplitudeNoPadDataset(PNGDataset):
    """PNGDataset variant yielding the Fourier amplitude of the unpadded image.

    Items are ``(image, amplitude, image)``, where ``amplitude`` is
    ``fft2d(image).abs()``.
    """

    def __getitem__(self, index):
        img = super().__getitem__(index)
        magnitude = fft2d(img).abs()
        # NOTE(review): the third slot holds the image itself, while
        # AmplitudeDataset returns a support mask in that position; confirm
        # this is intended by the consumers of this dataset.
        return img, magnitude, img
@register_dataset(name='amplitude_dataset')
class AmplitudeDataset(PNGDataset):
    """PNGDataset variant that pads each image before taking its Fourier amplitude.

    Items are ``(padded_image, amplitude, support)``:
      * ``padded_image``: the image after ``zero_padding_twice``;
      * ``amplitude``: ``fft2d(padded_image).abs()``;
      * ``support``: an all-ones tensor shaped like the unpadded image, passed
        through the same ``zero_padding_twice``.
    """

    def __getitem__(self, index):
        raw = super().__getitem__(index)
        ones = torch.ones_like(raw)
        # NOTE(review): presumably zero_padding_twice zero-fills, so the padded
        # ones mark where the original pixels sit. Confirm in utils.
        padded, mask = zero_padding_twice(raw), zero_padding_twice(ones)
        magnitude = fft2d(padded).abs()
        return padded, magnitude, mask
@register_dataset(name='noise_amplitude_dataset')
class NoiseAmplitudeDataset(AmplitudeDataset):
    """AmplitudeDataset whose amplitude gets additive Gaussian noise.

    The noise is standard normal scaled by ``sigma``. It is drawn anew on every
    ``__getitem__`` call, so repeated reads of the same index differ unless
    ``sigma`` is 0.
    """

    def __init__(self, root: str, train: bool, sigma: float, transform: Optional[Callable] = None):
        """
        Args:
            root: dataset root directory (see PNGDataset).
            train: select the ``train`` split when true, ``valid`` otherwise.
            sigma: scale applied to the standard-normal noise.
            transform: optional callable applied to each RGB PIL image.
        """
        super().__init__(root, train, transform)
        self.sigma = sigma

    def __getitem__(self, index):
        image, amplitude, support = super().__getitem__(index)
        # randn_like matches the amplitude's dtype and device; randn(shape)
        # would always use the global default dtype on the CPU.
        noise = torch.randn_like(amplitude) * self.sigma
        # NOTE(review): additive noise can push amplitude values below zero;
        # confirm downstream code tolerates negative amplitudes.
        amplitude += noise
        return image, amplitude, support