-
Notifications
You must be signed in to change notification settings - Fork 85
/
dataloader.py
65 lines (55 loc) · 1.59 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/usr/bin/env python
import sys
sys.path.append("..")
from utils import cv2_trans as transforms
from termcolor import cprint
import cv2
import torchvision
import torch.utils.data as data
import torch
import random
import numpy as np
import os
import warnings
class MagTrainDataset(data.Dataset):
    """Image/label dataset backed by a space-separated annotation file.

    Each non-blank line of ``ann_file`` is split on single spaces. Field 0 is
    the image path (later read with ``cv2.imread``) and field 2 is parsed with
    ``int`` as the sample's target. Field 1 and any further fields are ignored
    by this class.
    """

    def __init__(self, ann_file, transform=None):
        """Create the dataset and immediately parse ``ann_file``.

        Args:
            ann_file: Path to the annotation list file.
            transform: Optional callable applied to the image returned by
                ``cv2.imread``. When None, the image is returned unchanged.
        """
        self.ann_file = ann_file
        self.transform = transform
        self.init()

    def init(self):
        """(Re)load image paths and integer targets from ``self.ann_file``.

        Raises:
            ValueError: If a non-blank line has fewer than 3 fields.
        """
        # ``weight`` and ``pre_types`` are never filled in by this class; they
        # are kept as attributes in case external code reads them.
        self.weight = {}
        self.im_names = []
        self.targets = []
        self.pre_types = []
        with open(self.ann_file) as f:
            for lineno, line in enumerate(f, start=1):
                stripped = line.strip()
                if not stripped:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                # Named ``fields`` (not ``data``) to avoid shadowing the
                # torch.utils.data module alias used by the class header.
                fields = stripped.split(' ')
                if len(fields) < 3:
                    raise ValueError(
                        f"{self.ann_file}:{lineno}: expected at least 3 "
                        f"space-separated fields, got {len(fields)}: {stripped!r}"
                    )
                self.im_names.append(fields[0])
                self.targets.append(int(fields[2]))

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at ``index``.

        Raises:
            OSError: If ``cv2.imread`` cannot read the image. OpenCV signals
                this by returning None rather than raising, so it is
                checked explicitly here.
        """
        im_name = self.im_names[index]
        target = self.targets[index]
        img = cv2.imread(im_name)
        if img is None:
            raise OSError(f"cv2.imread could not read image: {im_name!r}")
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.im_names)
def train_loader(args):
    """Build the training DataLoader over ``args.train_list``.

    ``args`` must provide ``train_list`` (annotation file passed to
    ``MagTrainDataset``), ``batch_size`` and ``workers``. No custom sampler is
    used, so the loader shuffles itself and drops the final incomplete batch.
    """
    # Augmentation pipeline from utils.cv2_trans: a random horizontal flip
    # followed by conversion to a tensor (semantics defined in that module).
    augment = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    dataset = MagTrainDataset(args.train_list, transform=augment)

    # DataLoader forbids shuffle=True together with an explicit sampler, so
    # both shuffle and drop_last key off whether a sampler is supplied.
    sampler = None
    no_sampler = sampler is None
    return torch.utils.data.DataLoader(
        dataset,
        shuffle=no_sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=no_sampler,
    )