allow multiple timesteps of labels as outputs
HaydenFaulkner committed Dec 10, 2019
1 parent 4fa7471 commit 3406f06
Showing 1 changed file with 22 additions and 23 deletions.
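
The commit message is terse, so here is a minimal sketch (not part of the commit; names and shapes are illustrative assumptions) of what "multiple timesteps of labels as outputs" means in practice: when a temporal window is loaded and `_mult_out` is set, every frame in the window keeps its own label array, instead of only a single frame's label being returned alongside the stacked frames.

import mxnet as mx
import numpy as np

# Dummy 3-frame window: each "frame" is a 4x4 RGB image with one box label
# of the form [xmin, ymin, xmax, ymax, cls_id] (illustrative values only).
window = [mx.nd.array(np.random.rand(4, 4, 3)) for _ in range(3)]
labels = [np.array([[0.0, 0.0, 1.0, 1.0, 0.0]]) for _ in range(3)]

imgs, lbls = None, list()
for img, lbl in zip(window, labels):
    lbls.append(lbl)  # keep one label array per timestep
    if imgs is None:
        imgs = mx.nd.expand_dims(img, axis=0)  # first frame starts the stack
    else:
        imgs = mx.nd.concatenate([imgs, mx.nd.expand_dims(img, axis=0)], axis=0)

print(imgs.shape)  # (3, 4, 4, 3): frames stacked on a new leading axis
print(len(lbls))   # 3: with _mult_out set, all of these are returned, not just one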
45 changes: 22 additions & 23 deletions datasets/imgnetvid.py
@@ -137,7 +137,7 @@ def __getitem__(self, idx):
"""
if self._features_dir is not None:
img_path = self.sample_path(idx)
label = self._load_label(idx)[:, :-1] # remove track id
label = self._load_label(self.sample_ids[idx])[:, :-1] # remove track id
if self._window_size > 1: # lets load the temporal window
imgs = None
window_sample_ids = self._windows[self.sample_ids[idx]]
@@ -185,30 +185,30 @@ def __getitem__(self, idx):

        if not self._videos: # frames are samples
            img_path = self.sample_path(idx)
-            label = self._load_label(idx)[:, :-1] # remove track id
+            label = self._load_label(self.sample_ids[idx])[:, :-1] # remove track id

            if self._window_size > 1: # lets load the temporal window
                imgs = None
-                lbls = None
+                lbls = list()
                window_sample_ids = self._windows[self.sample_ids[idx]]

                # go through the sample ids for the window
                for sid in window_sample_ids:
                    img_path = self._image_path.format(*self.all_samples[sid])
                    img = mx.image.imread(img_path)
-                    lbl = self._load_label(self.sample_ids.index(sid))[:, :-1]
+                    lbl = self._load_label(sid)[:, :-1]

                    if self._transform is not None: # transform each image in the window
                        img, lbl = self._transform(img, lbl)

+                    lbls.append(lbl)
                    if imgs is None:
                        # imgs = img # is the first frame in the window
-                        imgs = mx.ndarray.expand_dims(img, axis=0) # is the first frame in the window
-                        lbls = mx.ndarray.expand_dims(lbl, axis=0)
+                        imgs = mx.nd.expand_dims(img, axis=0) # is the first frame in the window
                    else:
                        # imgs = mx.ndarray.concatenate([imgs, img], axis=2) # isn't first frame, concat to the window
-                        imgs = mx.ndarray.concatenate([imgs, mx.ndarray.expand_dims(img, axis=0)], axis=0)
-                        lbls = mx.ndarray.concatenate([lbls, mx.ndarray.expand_dims(lbl, axis=0)], axis=0)
+                        imgs = mx.nd.concatenate([imgs, mx.nd.expand_dims(img, axis=0)], axis=0)

                img = imgs
                if self._mult_out:
                    label = lbls
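
A plausible reason for switching `lbls` from a concatenated NDArray to a plain Python list (an assumption on my part, not stated in the commit): different frames in a window can contain different numbers of boxes, so stacking the per-frame label arrays along a new axis would fail on a shape mismatch, while a list sidesteps it. A small sketch with hypothetical data:

import numpy as np

# Hypothetical window of 3 frames with 2, 1 and 3 boxes respectively,
# each box being [xmin, ymin, xmax, ymax, cls_id].
frame_labels = [np.zeros((2, 5)), np.zeros((1, 5)), np.zeros((3, 5))]

lbls = list()
for lbl in frame_labels:
    lbls.append(lbl)  # mirrors the new lbls.append(lbl) in the diff

# np.stack(lbls) would raise a ValueError here because the box counts differ;
# returning the list leaves any padding/collation to the caller.
print([l.shape for l in lbls])  # [(2, 5), (1, 5), (3, 5)]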
@@ -242,7 +242,7 @@ def __getitem__(self, idx):
                # load the frame and the label
                img_id = (sample[0], sample[1], frame_id)
                img_path = self._image_path.format(*img_id)
-                label = self._load_label(idx, frame_id=frame_id)
+                label = self._load_label(sample_id, frame_id=frame_id)
                img = mx.image.imread(img_path, 1)

                # transform the image and label
@@ -318,13 +318,13 @@ def _remove_empties(self):
        good_sample_ids = list()
        removed = 0
        n_boxes = 0
-        for idx in tqdm(range(len(self.sample_ids)), desc="Removing images that have 0 boxes"):
-            n_boxes_in_sample = len(self._load_label(idx))
+        for sid in tqdm(self.sample_ids, desc="Removing images that have 0 boxes"):
+            n_boxes_in_sample = len(self._load_label(sid))
            if n_boxes_in_sample < 1:
                removed += 1
            else:
                n_boxes += n_boxes_in_sample
-                good_sample_ids.append(self.sample_ids[idx])
+                good_sample_ids.append(sid)

        str_ = "Removed {} out of {} images, leaving {} with {} boxes over {} classes.\n".format(
            removed, len(self.sample_ids), len(good_sample_ids), n_boxes, len(self.classes))
@@ -458,20 +458,19 @@ def _load_samples(self):

        return frames

-    def _load_label(self, idx, frame_id=None):
+    def _load_label(self, sid, frame_id=None):
        """
        Parse the xml annotation files for a sample
        Args:
-            idx (int): the sample index
+            sid (int): the sample id
            frame_id (str): needed if videos=True, will get the label for this particular frame
        Returns:
            numpy.ndarray : labels of shape (n, 6) - [[xmin, ymin, xmax, ymax, cls_id, trk_id], ...]
        """

-        sample_id = self.sample_ids[idx]
-        sample = self.samples[sample_id]
+        sample = self.all_samples[sid]

        anno_path = self._annotations_path.format(*sample)
        if self._videos:
@@ -487,8 +486,8 @@ def _load_label(self, idx, frame_id=None):
        height = float(size.find('height').text)

        # store the shapes for later usage
-        if sample_id not in self._im_shapes:
-            self._im_shapes[sample_id] = (width, height)
+        if sid not in self._im_shapes:
+            self._im_shapes[sid] = (width, height)

        label = []
        for obj in root.iter('object'):
@@ -556,8 +555,8 @@ def _pad_to_dense(labels, maxlen=100):

    def image_size(self, sample_id):
        if len(self._im_shapes) == 0:
-            for idx in tqdm(range(len(self.sample_ids)), desc="populating im_shapes"):
-                self._load_label(idx)
+            for sid in tqdm(self.sample_ids, desc="populating im_shapes"):
+                self._load_label(sid)
        return self._im_shapes[sample_id]

    def stats(self):
@@ -584,14 +583,14 @@ def stats(self):
            if self._videos:
                for frame_id in self.samples[sample_id][2]:
                    n_frames += 1
-                    for box in self._load_label(idx, frame_id):
+                    for box in self._load_label(self.sample_ids[idx], frame_id):
                        if int(box[4]) < 0: # not actually a box
                            continue
                        n_boxes[int(box[4])] += 1
                        vid_instances[int(box[4])].add(vid_id+str(box[-1])) # add the track id
            else:
                n_frames += 1
-                for box in self._load_label(idx):
+                for box in self._load_label(self.sample_ids[idx]):
                    if int(box[4]) < 0: # not actually a box
                        continue
                    n_boxes[int(box[4])] += 1
@@ -649,7 +648,7 @@ def build_coco_json(self):
                           'height': int(height),
                           'id': sample_id})

-            for box in self._load_label(idx):
+            for box in self._load_label(sample_id):
                xywh = [int(box[0]), int(box[1]), int(box[2])-int(box[0]), int(box[3])-int(box[1])]
                annotations.append({'image_id': sample_id,
                                    'id': len(annotations),
