-
Notifications
You must be signed in to change notification settings - Fork 36
/
data_parser.py
65 lines (54 loc) · 2.19 KB
/
data_parser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import csv
from collections import namedtuple
# Record for one labelled sample: the id and label as read from the input CSV,
# plus the id joined onto the data root.
ListDataJpeg = namedtuple('ListDataJpeg', ['id', 'label', 'path'])

# Record for one unlabelled sample (no 'label' field).
# The typename matches the variable it is bound to so that repr() identifies
# the right type and pickle can resolve the class by name; reusing the
# 'ListDataJpeg' typename made both of those point at the labelled record.
ListDataJpeg_test = namedtuple('ListDataJpeg_test', ['id', 'path'])
class JpegDataset(object):
    """Index of labelled samples built from a labels CSV and an input CSV.

    Attributes set by the constructor:
        classes: list of class names, in the order they appear in the
            labels CSV.
        classes_dict: mapping that goes both ways between a class name and
            its position in ``classes``.
        csv_data: list of ``ListDataJpeg`` records for every input row whose
            label is one of ``classes``.
    """

    def __init__(self, csv_path_input, csv_path_labels, data_root):
        """Load the labels first, then the inputs filtered by those labels.

        Args:
            csv_path_input: path to the ';'-delimited CSV of samples.
            csv_path_labels: path to the CSV listing one class name per row.
            data_root: prefix joined onto each sample id to form its path.
        """
        # Order matters: read_csv_input filters on self.classes.
        self.classes = self.read_csv_labels(csv_path_labels)
        self.classes_dict = self.get_two_way_dict(self.classes)
        self.csv_data = self.read_csv_input(csv_path_input, data_root)

    def read_csv_input(self, csv_path, data_root):
        """Read the ';'-delimited input CSV and keep rows with a known label.

        Column 0 is used as the sample id and column 1 as its label. Rows
        whose label is not in ``self.classes`` are dropped, and blank lines
        are skipped.

        Args:
            csv_path: path to the input CSV.
            data_root: prefix joined onto the sample id to form ``path``.

        Returns:
            A list of ``ListDataJpeg(id, label, path)`` in file order.
        """
        # Build the lookup once instead of scanning the list for every row.
        known_labels = set(self.classes)
        csv_data = []
        with open(csv_path) as csvfile:
            for row in csv.reader(csvfile, delimiter=';'):
                if not row:
                    # csv.reader yields [] for a blank line; indexing it
                    # would raise IndexError.
                    continue
                if row[1] in known_labels:
                    csv_data.append(
                        ListDataJpeg(row[0],
                                     row[1],
                                     os.path.join(data_root, row[0]))
                    )
        return csv_data

    def read_csv_labels(self, csv_path):
        """Read class names from the first column of a comma-delimited CSV.

        Blank lines are skipped.

        Args:
            csv_path: path to the labels CSV.

        Returns:
            A list of class names in file order.
        """
        with open(csv_path) as csvfile:
            # `if row` drops the empty lists produced by blank lines.
            return [row[0] for row in csv.reader(csvfile) if row]

    def get_two_way_dict(self, classes):
        """Build a dict mapping each class to its index and each index back.

        Args:
            classes: iterable of hashable class names.

        Returns:
            A dict with both ``name -> index`` and ``index -> name`` entries.
        """
        classes_dict = {}
        for i, item in enumerate(classes):
            classes_dict[item] = i
            classes_dict[i] = item
        return classes_dict
class JpegDataset_test(object):
    """Index of unlabelled samples built from an input CSV.

    Attributes set by the constructor:
        csv_data: list of ``ListDataJpeg_test`` records, one per non-blank
            row of the input CSV.
    """

    def __init__(self, csv_path_input, data_root):
        """Load every sample listed in the input CSV.

        Args:
            csv_path_input: path to the ';'-delimited CSV of samples.
            data_root: prefix joined onto each sample id to form its path.
        """
        self.csv_data = self.read_csv_input(csv_path_input, data_root)

    def read_csv_input(self, csv_path, data_root):
        """Read the ';'-delimited input CSV into a list of records.

        Column 0 is used as the sample id; any further columns are ignored.
        Blank lines are skipped.

        Args:
            csv_path: path to the input CSV.
            data_root: prefix joined onto the sample id to form ``path``.

        Returns:
            A list of ``ListDataJpeg_test(id, path)`` in file order.
        """
        with open(csv_path) as csvfile:
            # `if row` drops the empty lists csv.reader produces for blank
            # lines, which would otherwise raise IndexError on row[0].
            return [
                ListDataJpeg_test(row[0], os.path.join(data_root, row[0]))
                for row in csv.reader(csvfile, delimiter=';')
                if row
            ]

    def get_two_way_dict(self, classes):
        """Build a dict mapping each class to its index and each index back.

        Args:
            classes: iterable of hashable class names.

        Returns:
            A dict with both ``name -> index`` and ``index -> name`` entries.
        """
        classes_dict = {}
        for i, item in enumerate(classes):
            classes_dict[item] = i
            classes_dict[i] = item
        return classes_dict