-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
47 lines (31 loc) · 1.16 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize
from dataset import *
def calculate_valid_crop_size(crop_size, upscale_factor):
    """Round ``crop_size`` down to a multiple of ``upscale_factor``.

    Args:
        crop_size: The requested crop size.
        upscale_factor: The value the result must be divisible by.

    Returns:
        ``crop_size`` minus its remainder modulo ``upscale_factor``.
    """
    leftover = crop_size % upscale_factor
    return crop_size - leftover
# Default side length passed to CenterCrop by transform().
crop_size = 256


def transform(size=None):
    """Build the preprocessing pipeline used for both input and target images.

    The pipeline center-crops the image and then converts it with
    torchvision's ``ToTensor``.

    Args:
        size: Crop size passed to ``CenterCrop``. When ``None`` (the
            default), the module-level ``crop_size`` is used. It is read at
            call time, so reassigning ``crop_size`` still takes effect.

    Returns:
        A ``Compose`` of ``CenterCrop(size)`` followed by ``ToTensor()``.
    """
    # ``crop_size`` is only read here, so no ``global`` declaration is needed.
    if size is None:
        size = crop_size
    return Compose([
        CenterCrop(size),
        ToTensor(),
    ])
def get_training_set(root_dir='datasets/train/'):
    """Return the training dataset built from ``root_dir``'s LR/HR subfolders.

    Args:
        root_dir: Directory that contains ``LR`` and ``HR`` subdirectories.
            Defaults to ``'datasets/train/'``.

    Returns:
        A ``DatasetFromFolder`` over ``<root_dir>/LR`` and ``<root_dir>/HR``.
        Inputs and targets each get their own ``transform()`` pipeline.
    """
    # Import explicitly instead of relying on ``join`` being re-exported by
    # ``from dataset import *``.
    import os.path

    lr_dir = os.path.join(root_dir, "LR")
    hr_dir = os.path.join(root_dir, "HR")
    return DatasetFromFolder(lr_dir, hr_dir,
                             input_transform=transform(),
                             target_transform=transform())
def get_test_set(root_dir='datasets/test/set14/'):
    """Return the test dataset built from ``root_dir``'s LR/HR subfolders.

    Args:
        root_dir: Directory that contains ``LR`` and ``HR`` subdirectories.
            Defaults to ``'datasets/test/set14/'``.

    Returns:
        A ``DatasetFromFolder`` over ``<root_dir>/LR`` and ``<root_dir>/HR``.
        Inputs and targets each get their own ``transform()`` pipeline.
    """
    # Import explicitly instead of relying on ``join`` being re-exported by
    # ``from dataset import *``.
    import os.path

    lr_dir = os.path.join(root_dir, "LR")
    hr_dir = os.path.join(root_dir, "HR")
    return DatasetFromFolder(lr_dir, hr_dir,
                             input_transform=transform(),
                             target_transform=transform())
def get_valid_set(root_dir='datasets/valid/'):
    """Return the validation dataset built from ``root_dir``'s LR/HR subfolders.

    Args:
        root_dir: Directory that contains ``LR`` and ``HR`` subdirectories.
            Defaults to ``'datasets/valid/'``.

    Returns:
        A ``DatasetFromFolder`` over ``<root_dir>/LR`` and ``<root_dir>/HR``.
        Inputs and targets each get their own ``transform()`` pipeline.
    """
    # Import explicitly instead of relying on ``join`` being re-exported by
    # ``from dataset import *``.
    import os.path

    lr_dir = os.path.join(root_dir, "LR")
    hr_dir = os.path.join(root_dir, "HR")
    return DatasetFromFolder(lr_dir, hr_dir,
                             input_transform=transform(),
                             target_transform=transform())