Skip to content

Commit

Permalink
Hot pixel filtering supported during inference
Browse files Browse the repository at this point in the history
  • Loading branch information
TimoStoff committed Jul 14, 2020
1 parent e465b0c commit b5bb4a8
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
20 changes: 18 additions & 2 deletions data_loader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from utils.data import data_sources
from events_contrast_maximization.utils.event_utils import events_to_voxel_torch, \
events_to_neg_pos_voxel_torch, binary_search_torch_tensor, events_to_image_torch, \
binary_search_h5_dset
binary_search_h5_dset, get_hot_event_mask, save_image
from utils.util import read_json, write_json


Expand Down Expand Up @@ -101,8 +101,10 @@ def find_ts_index(self, timestamp):
"""
raise NotImplementedError


def __init__(self, data_path, transforms={}, sensor_resolution=None, num_bins=5,
voxel_method=None, max_length=None, combined_voxel_channels=True):
voxel_method=None, max_length=None, combined_voxel_channels=True,
filter_hot_events=False):
"""
self.transform applies to event voxels, frames and flow.
self.vox_transform applies to event voxels only.
Expand All @@ -114,6 +116,7 @@ def __init__(self, data_path, transforms={}, sensor_resolution=None, num_bins=5,
self.sensor_resolution = sensor_resolution
self.data_source_idx = -1
self.has_flow = False
self.channels = self.num_bins if combined_voxel_channels else self.num_bins*2

self.sensor_resolution, self.t0, self.tk, self.num_events, self.frame_ts, self.num_frames = \
None, None, None, None, None, None
Expand All @@ -128,6 +131,17 @@ def __init__(self, data_path, transforms={}, sensor_resolution=None, num_bins=5,
self.num_pixels = self.sensor_resolution[0] * self.sensor_resolution[1]
self.duration = self.tk - self.t0

if filter_hot_events:
secs_for_hot_mask = 0.2
hot_pix_percent = 0.01
hot_num = min(self.find_ts_index(secs_for_hot_mask+self.t0), self.num_events)
xs, ys, ts, ps = self.get_events(0, hot_num)
self.hot_events_mask = get_hot_event_mask(xs.astype(np.int), ys.astype(np.int), ps, self.sensor_resolution, num_hot=int(self.num_pixels*hot_pix_percent))
self.hot_events_mask = np.stack([self.hot_events_mask]*self.channels, axis=2).transpose(2,0,1)
self.hot_events_mask = torch.from_numpy(self.hot_events_mask).float()
else:
self.hot_events_mask = np.ones([self.channels, *self.sensor_resolution])

if voxel_method is None:
voxel_method = {'method': 'between_frames'}
self.set_voxel_method(voxel_method)
Expand Down Expand Up @@ -311,6 +325,8 @@ def get_voxel_grid(self, xs, ys, ts, ps, combined_voxel_channels=True):
sensor_size=self.sensor_resolution)
voxel_grid = torch.cat([voxel_grid[0], voxel_grid[1]], 0)

voxel_grid = voxel_grid*self.hot_events_mask

return voxel_grid

def transform_frame(self, frame, seed):
Expand Down
3 changes: 3 additions & 0 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def main(args, model):
'max_length': None,
'sensor_resolution': None,
'num_bins': 5,
'filter_hot_events': args.filter_hot_events,
'voxel_method': {'method': args.voxel_method,
'k': args.k,
't': args.t,
Expand Down Expand Up @@ -158,6 +159,8 @@ def main(args, model):
help='sliding_window size in seconds (required if voxel_method is t_seconds)')
parser.add_argument('--loader_type', default='H5', type=str,
help='Which data format to load (HDF5 recommended)')
parser.add_argument('--filter_hot_events', action='store_true',
help='If true, auto-detect and remove hot pixels')
parser.add_argument('--legacy_norm', action='store_true', default=False,
help='Normalize nonzero entries in voxel to have mean=0, std=1 according to Rebecq20PAMI and Scheerlinck20WACV.'
'If --e2vid or --firenet_legacy are set, --legacy_norm will be set to True (default False).')
Expand Down

0 comments on commit b5bb4a8

Please sign in to comment.