# dataset.py
# Forked from mapleneverfade/pytorch-semantic-segmentation.
import numpy as np
import os

from PIL import Image
from torch.utils.data import Dataset

EXTENSIONS = ['.jpg', '.png', '.JPG', '.PNG']


def load_image(file):
    return Image.open(file)


def is_image(filename):
    return any(filename.endswith(ext) for ext in EXTENSIONS)


def image_path(root, basename, extension):
    return os.path.join(root, '{}{}'.format(basename, extension))


def image_path_city(root, name):
    return os.path.join(root, '{}'.format(name))


def image_basename(filename):
    return os.path.basename(os.path.splitext(filename)[0])

class NeoData(Dataset):
    """Training/validation dataset built from two list files.

    Each list file holds one path per line; line i of the label list is the
    ground-truth mask for line i of the image list.
    """

    def __init__(self, imagepath=None, labelpath=None, transform=None):
        self.transform = transform
        assert os.path.exists(imagepath), "{} does not exist!".format(imagepath)
        assert os.path.exists(labelpath), "{} does not exist!".format(labelpath)

        # Read the image and label lists; entries are kept in matching order.
        self.image = []
        self.label = []
        with open(imagepath, 'r') as f:
            for line in f:
                self.image.append(line.strip())
        with open(labelpath, 'r') as f:
            for line in f:
                self.label.append(line.strip())

    def __getitem__(self, index):
        filename = self.image[index]
        filenameGt = self.label[index]

        with open(filename, 'rb') as f:
            image = load_image(f).convert('RGB')
        with open(filenameGt, 'rb') as f:
            label = load_image(f).convert('P')

        if self.transform is not None:
            image, label = self.transform(image, label)

        return image, label

    def __len__(self):
        return len(self.image)
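
# Usage sketch for NeoData (illustrative only: the list-file names and the
# joint transform are placeholders, not files or code shipped with this repo):
#
#   train_set = NeoData(imagepath='train_images.txt',
#                       labelpath='train_labels.txt',
#                       transform=my_joint_transform)
#   train_loader = torch.utils.data.DataLoader(train_set, batch_size=4,
#                                              shuffle=True)
#
# The joint transform receives (PIL image, PIL label) and should return the
# pair ready for training, e.g. as tensors of matching spatial size.
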
class NeoData_test(Dataset):
    """Test dataset: also returns the original image so the segmented area can
    be drawn on top of it."""

    def __init__(self, imagepath=None, labelpath=None, transform=None):
        self.transform = transform
        assert os.path.exists(imagepath), "{} does not exist!".format(imagepath)
        assert os.path.exists(labelpath), "{} does not exist!".format(labelpath)

        # Read the image and label lists; entries are kept in matching order.
        self.image = []
        self.label = []
        with open(imagepath, 'r') as f:
            for line in f:
                self.image.append(line.strip())
        with open(labelpath, 'r') as f:
            for line in f:
                self.label.append(line.strip())
        print("Length of test data is {}".format(len(self.image)))

    def __getitem__(self, index):
        filename = self.image[index]
        filenameGt = self.label[index]

        with open(filename, 'rb') as f:
            image = load_image(f).convert('RGB')
        with open(filenameGt, 'rb') as f:
            label = load_image(f).convert('P')

        # The test transform is required here: it must return the transformed
        # tensors together with the (possibly resized) image, so callers can
        # overlay predictions on the original image.
        image_tensor, label_tensor, img = self.transform(image, label)
        return image_tensor, label_tensor, np.array(img)

    def __len__(self):
        return len(self.image)
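

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original project). The
    # list files 'val_images.txt' / 'val_labels.txt' are assumed to exist and
    # contain one path per line; the transform below is a stand-in for the
    # repository's own joint transform, which for NeoData_test must return
    # (image_tensor, label_tensor, image).
    import torch
    from torch.utils.data import DataLoader
    from torchvision import transforms

    to_tensor = transforms.ToTensor()

    def demo_test_transform(image, label):
        # Resize image and label identically so they stay pixel-aligned, then
        # keep the resized image so predictions can be visualised on it.
        image = image.resize((512, 256), Image.BILINEAR)
        label = label.resize((512, 256), Image.NEAREST)
        return to_tensor(image), torch.from_numpy(np.array(label)).long(), image

    test_set = NeoData_test(imagepath='val_images.txt',
                            labelpath='val_labels.txt',
                            transform=demo_test_transform)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

    for image_tensor, label_tensor, original in test_loader:
        # After default collation, 'original' is a (1, H, W, 3) uint8 tensor.
        print(image_tensor.shape, label_tensor.shape, original.shape)
        break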