This repository has been archived by the owner on Jun 24, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path dataset.py
65 lines (51 loc) · 2.31 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os, torch, torchvision
from PIL import Image
from pycocotools.coco import COCO
from torch.utils.data import Dataset
class COCO_detection(Dataset):
    """COCO-format detection dataset yielding ``(image, bboxs, labels)`` tuples.

    Boxes are converted from the annotation's ``[min_x, min_y, w, h]`` layout to
    corner form ``[min_x, min_y, max_x, max_y]``. Category ids are remapped to
    contiguous indices ``0..K-1`` in the order returned by ``COCO.getCatIds()``.

    Args:
        img_dir: directory containing the image files named in the annotation file.
        ann: annotation file passed straight to ``pycocotools.coco.COCO``.
        transforms: optional callable ``transforms(image, bboxs) -> (image, bboxs)``.
            NOTE: labels are not passed to it, so a transform that drops or
            reorders boxes would desynchronise boxes and labels.
    """

    def __init__(self, img_dir, ann, transforms=None):
        super().__init__()
        self.img_dir = img_dir
        self.transforms = transforms
        self.coco = COCO(ann)
        # Sorted image ids give a deterministic index -> image-id mapping.
        self.ids = sorted(self.coco.imgs.keys())
        # Raw category ids need not be contiguous; map them to 0..K-1.
        self.label_map = {raw_label: i for i, raw_label in enumerate(self.coco.getCatIds())}

    def _load_image(self, id_):
        """Open the image for COCO image id ``id_`` and return it as an RGB PIL image."""
        file_name = self.coco.loadImgs(id_)[0]['file_name']
        # convert() returns a new, fully loaded image, so the source file can be
        # closed deterministically when the with-block exits.
        with Image.open(os.path.join(self.img_dir, file_name)) as img:
            return img.convert('RGB')

    def _load_target(self, id_):
        """Return ``(bboxs, labels)`` for image ``id_``.

        Returns:
            ``bboxs``: float32 tensor of shape ``(N, 4)`` in corner form.
            ``labels``: int64 tensor of shape ``(N,)`` holding remapped category indices.
            ``(None, None)`` when the image has no annotations.
        """
        # Query the annotation index once and reuse the result.
        anns = self.coco.loadAnns(self.coco.getAnnIds(id_))
        if not anns:
            return None, None
        boxes, labels = [], []
        for ann in anns:
            min_x, min_y, w, h = ann['bbox']
            boxes.append([min_x, min_y, min_x + w, min_y + h])
            labels.append(self.label_map[ann['category_id']])
        return torch.tensor(boxes, dtype=torch.float32), torch.tensor(labels, dtype=torch.long)

    def __getitem__(self, index):
        """Return ``(image, bboxs, labels)`` for the ``index``-th image id.

        ``bboxs`` and ``labels`` are ``None`` for images without annotations. In
        that case ``transforms`` (if set) receives ``None`` as its boxes argument.
        """
        id_ = self.ids[index]
        image = self._load_image(id_)
        bboxs, labels = self._load_target(id_)
        if self.transforms is not None:
            image, bboxs = self.transforms(image, bboxs)
        return image, bboxs, labels

    def __len__(self):
        """Number of images (annotated or not) in the annotation file."""
        return len(self.ids)
class COCO_detection_raw(Dataset):
    """COCO-format dataset yielding ``(image, target)`` with unprocessed annotations.

    ``target`` is the list of annotation dicts returned by ``COCO.loadAnns`` for
    the image, unmodified. It is empty for images without annotations.

    Args:
        img_dir: directory containing the image files named in the annotation file.
        ann: annotation file passed straight to ``pycocotools.coco.COCO``.
        transforms: optional callable ``transforms(image, target) -> (image, target)``.
    """

    def __init__(self, img_dir, ann, transforms=None):
        # Zero-argument super() always refers to this class, so it cannot drift
        # out of sync with the class name.
        super().__init__()
        self.img_dir = img_dir
        self.transforms = transforms
        self.coco = COCO(ann)
        # Sorted image ids give a deterministic index -> image-id mapping.
        self.ids = sorted(self.coco.imgs.keys())

    def _load_image(self, id_):
        """Open the image for COCO image id ``id_`` and return it as an RGB PIL image."""
        file_name = self.coco.loadImgs(id_)[0]['file_name']
        # convert() returns a new, fully loaded image, so the source file can be
        # closed deterministically when the with-block exits.
        with Image.open(os.path.join(self.img_dir, file_name)) as img:
            return img.convert('RGB')

    def _load_target(self, id_):
        """Return the raw annotation dicts for image ``id_``."""
        return self.coco.loadAnns(self.coco.getAnnIds(id_))

    def __getitem__(self, index):
        """Return ``(image, target)`` for the ``index``-th image id."""
        id_ = self.ids[index]
        image = self._load_image(id_)
        target = self._load_target(id_)
        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target

    def __len__(self):
        """Number of images (annotated or not) in the annotation file."""
        return len(self.ids)