-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy path__init__.py
39 lines (35 loc) · 1.44 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
'''create dataset and dataloader'''
import logging
from re import split
import torch.utils.data
def create_dataloader(dataset, dataset_opt, phase):
    '''Create a ``torch.utils.data.DataLoader`` wrapping ``dataset``.

    Args:
        dataset: The dataset to load from.
        dataset_opt: Option mapping. For ``phase == 'train'`` the keys
            ``'batch_size'``, ``'use_shuffle'`` and ``'num_workers'`` are read;
            it is not consulted for ``phase == 'val'``.
        phase: ``'train'`` or ``'val'``.

    Returns:
        A DataLoader with ``pin_memory=True``. Training takes batch size,
        shuffle flag and worker count from ``dataset_opt``; validation always
        uses batch size 1, no shuffling and no worker processes.

    Raises:
        NotImplementedError: if ``phase`` is neither ``'train'`` nor ``'val'``.
    '''
    if phase == 'train':
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=dataset_opt['batch_size'],
            shuffle=dataset_opt['use_shuffle'],
            num_workers=dataset_opt['num_workers'],
            pin_memory=True)
    if phase == 'val':
        return torch.utils.data.DataLoader(
            dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
    # Plain '{}' instead of '{:s}': the ':s' spec raises TypeError/ValueError
    # for non-string phases (e.g. None), which would mask this error.
    raise NotImplementedError(
        'Dataloader [{}] is not found.'.format(phase))
def create_dataset(dataset_opt, phase):
    '''Create an ``LRHRDataset`` from ``dataset_opt`` and log its creation.

    Args:
        dataset_opt: Option mapping. Reads ``'mode'``, ``'dataroot'``,
            ``'datatype'``, ``'l_resolution'``, ``'r_resolution'``,
            ``'data_len'`` and ``'name'``.
        phase: Forwarded to the dataset as its ``split`` argument.

    Returns:
        The constructed ``LRHRDataset`` instance.
    '''
    mode = dataset_opt['mode']
    # Function-local import kept from the original; the reason is not visible
    # here (possibly to defer loading the module or avoid an import cycle).
    from data.LRHR_dataset import LRHRDataset as D
    dataset = D(dataroot=dataset_opt['dataroot'],
                datatype=dataset_opt['datatype'],
                l_resolution=dataset_opt['l_resolution'],
                r_resolution=dataset_opt['r_resolution'],
                split=phase,
                data_len=dataset_opt['data_len'],
                # need_LR is True only when the configured mode is 'LRHR'.
                need_LR=(mode == 'LRHR')
                )
    logger = logging.getLogger('base')
    # Lazy %-style args: the message is only formatted if INFO is enabled,
    # and unlike '{:s}' a non-string name cannot crash after the dataset
    # has already been built.
    logger.info('Dataset [%s - %s] is created.',
                dataset.__class__.__name__, dataset_opt['name'])
    return dataset