Skip to content

Commit

Permalink
Merge pull request #7 from TimoStoff/cedric/fix_dataloader_index+1_bug
Browse files Browse the repository at this point in the history
untested
  • Loading branch information
cedric-scheerlinck authored Jul 7, 2020
2 parents f4325c7 + 2eb6a4d commit 7aed76f
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions data_loader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,29 +169,28 @@ def __getitem__(self, index, seed=None):

idx0, idx1 = self.get_event_indices(index)
xs, ys, ts, ps = self.get_events(idx0, idx1)
if len(xs) == 0:
xs = torch.zeros((1), dtype=torch.float32)
ys = torch.zeros((1), dtype=torch.float32)
ts = torch.zeros((1), dtype=torch.float32)
ps = torch.zeros((1), dtype=torch.float32)
try:
ts_0, ts_k = ts[0], ts[-1]
except:
ts_0, ts_k = 0, 0
if len(xs) < 3:
voxel = self.get_empty_voxel_grid(self.combined_voxel_channels)
else:
ts_0, ts_k = ts[0], ts[-1]
xs = torch.from_numpy(xs.astype(np.float32))
ys = torch.from_numpy(ys.astype(np.float32))
ts = torch.from_numpy((ts-ts_0).astype(np.float32))
ps = torch.from_numpy(ps.astype(np.float32))
dt = ts[-1] - ts[0]
voxel = self.get_voxel_grid(xs, ys, ts, ps, combined_voxel_channels=self.combined_voxel_channels)

voxel = self.get_voxel_grid(xs, ys, ts, ps, combined_voxel_channels=self.combined_voxel_channels)
voxel = self.transform_voxel(voxel, seed)
dt = ts_k - ts_0

if self.voxel_method['method'] == 'between_frames':
frame = self.get_frame(index + 1)
frame = self.get_frame(index)
frame = self.transform_frame(frame, seed)

if self.has_flow:
flow = self.get_flow(index + 1)
flow = self.get_flow(index)
# convert to displacement (pix)
flow = flow * dt
flow = self.transform_flow(flow, seed)
Expand All @@ -201,7 +200,7 @@ def __getitem__(self, index, seed=None):
item = {'frame': frame,
'flow': flow,
'events': voxel,
'timestamp': torch.tensor(ts_k, dtype=torch.float64),
'timestamp': torch.tensor(self.frame_ts[index], dtype=torch.float64),
'data_source_idx': self.data_source_idx,
'dt': torch.tensor(dt, dtype=torch.float64)}
else:
Expand Down Expand Up @@ -284,6 +283,13 @@ def get_event_indices(self, index):
raise Exception("WARNING: Event indices {},{} out of bounds 0,{}".format(idx0, idx1, self.num_events))
return idx0, idx1

def get_empty_voxel_grid(self, combined_voxel_channels=True):
"""Return an empty voxel grid filled with zeros"""
size = (self.num_bins, *self.sensor_resolution)
if not combined_voxel_channels:
size = (2, *size)
return torch.zeros(size)

def get_voxel_grid(self, xs, ys, ts, ps, combined_voxel_channels=True):
"""
Given events, return voxel grid
Expand Down

0 comments on commit 7aed76f

Please sign in to comment.