Added color inference
TimoStoff committed May 12, 2020
1 parent 1c16fe7 commit 659010b
Showing 5 changed files with 173 additions and 80 deletions.
61 changes: 10 additions & 51 deletions data_loader/dataset.py
@@ -41,8 +41,7 @@ class DynamicH5Dataset(Dataset):
     """

     def __init__(self, h5_path, transforms=None, sensor_size=None, num_bins=5,
-                 voxel_method=None, max_length=None, combined_voxel_channels=True,
-                 legacy=False):
+                 voxel_method=None, max_length=None, combined_voxel_channels=True):
         if transforms is None:
             transforms = {}
         if voxel_method is None:
@@ -56,9 +55,10 @@ def __init__(self, h5_path, transforms=None, sensor_size=None, num_bins=5,
         self.num_bins = num_bins
         self.voxel_method = voxel_method
         if sensor_size is None:
-            self.sensor_size = self.h5_file.attrs['sensor_resolution']
+            self.sensor_size = self.h5_file.attrs['sensor_resolution'][0:2]
+            print("sensor size = {}".format(self.sensor_size))
         else:
-            self.sensor_size = sensor_size
+            self.sensor_size = sensor_size[0:2]

         self.num_events = self.h5_file.attrs['num_events']
         self.duration = self.h5_file.attrs['duration']
@@ -103,7 +103,6 @@ def __init__(self, h5_path, transforms=None, sensor_size=None, num_bins=5,
         if max_length is not None:
            self.length = min(self.length, max_length + 1)
         self.combined_voxel_channels = combined_voxel_channels
-        self.legacy = legacy

     def __len__(self):
         return self.length
@@ -128,39 +127,6 @@ def transform_flow(self, flow, seed):
         flow = self.transform(flow, is_flow=True)
         return flow

-    @staticmethod
-    def events_to_voxel_legacy(xs, ys, ts, ps, B, sensor_size=(180, 240), device=None):
-        if device is None:
-            device = xs.device
-        assert (len(xs) == len(ys) and len(ys) == len(ts) and len(ts) == len(ps))
-        # num_events_per_bin = len(xs) // B
-        dt = ts[-1] - ts[0]
-
-        bins = []
-        for bi in range(B):
-            bins.append(torch.zeros(list(sensor_size), device=device))
-
-        bin_time_width = dt / (B - 1.0)
-        lower_bin_ts = ts[0]
-        upper_bin_ts = lower_bin_ts + bin_time_width
-        for bi in range(B - 1):
-            beg = binary_search_torch_tensor(ts, 0, len(ts) - 1, lower_bin_ts)
-            end = binary_search_torch_tensor(ts, 0, len(ts) - 1, upper_bin_ts) - 1
-
-            factor_upper = (ts - lower_bin_ts) / bin_time_width
-            factor_lower = 1.0 - factor_upper
-
-            bins[bi] = bins[bi] + events_to_image_torch(xs[beg:end], ys[beg:end], (factor_lower * ps)[beg:end],
-                                                        device=device, sensor_size=sensor_size, clip_out_of_range=True)
-            bins[bi + 1] = bins[bi + 1] + events_to_image_torch(xs[beg:end], ys[beg:end], (factor_upper * ps)[beg:end],
-                                                                device=device, sensor_size=sensor_size,
-                                                                clip_out_of_range=True)
-            lower_bin_ts += bin_time_width
-            upper_bin_ts += bin_time_width
-
-        bins = torch.stack(bins)
-        return bins
-
     def __getitem__(self, i, seed=None):
         assert (i >= 0)
         assert (i < self.length)
@@ -175,12 +141,8 @@ def __getitem__(self, i, seed=None):

             timestamp = img_dset.attrs['timestamp']

-            if self.legacy:
-                events_start_idx = self.h5_file['images']['image{:09d}'.format(i)].attrs['event_idx']
-                events_end_idx = self.h5_file['images']['image{:09d}'.format(i + 1)].attrs['event_idx']
-            else:
-                events_start_idx = self.h5_file['images']['image{:09d}'.format(i)].attrs['event_idx'] + 1
-                events_end_idx = img_dset.attrs['event_idx']
+            events_start_idx = self.h5_file['images']['image{:09d}'.format(i)].attrs['event_idx'] + 1
+            events_end_idx = img_dset.attrs['event_idx']
         elif self.voxel_method['method'] == 'k_events':
             events_start_idx = i * self.voxel_method['sliding_window_w']
             events_end_idx = self.voxel_method['k'] + events_start_idx
@@ -204,14 +166,11 @@ def __getitem__(self, i, seed=None):
             (self.h5_file['events/ts'][events_start_idx:events_end_idx] - self.t0).astype(np.float32))  # H x W
         ps = torch.from_numpy(
             (self.h5_file['events/ps'][events_start_idx:events_end_idx] * 2 - 1).astype(np.float32))  # H x W
-        if self.legacy:
-            voxel = self.events_to_voxel_legacy(xs, ys, ts, ps, self.num_bins, sensor_size=self.sensor_size).float()
+        if self.combined_voxel_channels:
+            voxel = events_to_voxel_torch(xs, ys, ts, ps, self.num_bins, sensor_size=self.sensor_size).float()
         else:
-            if self.combined_voxel_channels:
-                voxel = events_to_voxel_torch(xs, ys, ts, ps, self.num_bins, sensor_size=self.sensor_size).float()
-            else:
-                voxel = events_to_neg_pos_voxel_torch(xs, ys, ts, ps, self.num_bins, sensor_size=self.sensor_size)
-                voxel = torch.cat([voxel[0], voxel[1]], dim=0).float()
+            voxel = events_to_neg_pos_voxel_torch(xs, ys, ts, ps, self.num_bins, sensor_size=self.sensor_size)
+            voxel = torch.cat([voxel[0], voxel[1]], dim=0).float()

         if seed is None:
             seed = random.randint(0, 2 ** 32)
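The removed events_to_voxel_legacy and its replacement events_to_voxel_torch (from the events_contrast_maximization submodule) build the same representation: each event splits its polarity vote between the two temporally adjacent bins of a B-bin voxel grid, weighted linearly by its timestamp. A minimal self-contained sketch of that binning — the function name and the index_add_ formulation here are illustrative, not the submodule's actual implementation:

import torch

def events_to_voxel_sketch(xs, ys, ts, ps, B, sensor_size=(180, 240)):
    # Bilinear-in-time voxel grid: each event votes into its two nearest bins.
    H, W = sensor_size
    voxel = torch.zeros(B, H, W)
    t = (ts - ts[0]) / (ts[-1] - ts[0] + 1e-9) * (B - 1)  # map timestamps to [0, B-1]
    left = t.floor().long().clamp(0, B - 1)               # earlier bin index
    right = (left + 1).clamp(0, B - 1)                    # later bin index
    w_right = t - left.float()                            # weight for the later bin
    w_left = 1.0 - w_right                                # weight for the earlier bin
    pix = ys.long() * W + xs.long()                       # flattened pixel index
    flat = voxel.view(-1)                                 # B*H*W accumulator
    flat.index_add_(0, left * H * W + pix, ps * w_left)
    flat.index_add_(0, right * H * W + pix, ps * w_right)
    return voxel

With combined_voxel_channels=False the dataset instead builds one such grid per polarity (events_to_neg_pos_voxel_torch) and concatenates them along the channel axis, which is why the updated-style inference path below doubles the input channels.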
2 changes: 1 addition & 1 deletion events_contrast_maximization (submodule pointer update)
44 changes: 16 additions & 28 deletions inference.py
@@ -10,6 +10,7 @@
 from utils.util import ensure_dir, flow2bgr_np
 from model import model as model_arch
 from data_loader.data_loaders import InferenceDataLoader
+from model.model import ColorNet
 from utils.util import CropParameters
 from utils.timers import CudaTimer

@@ -36,6 +37,8 @@ def load_model(checkpoint):

     model = model.to(device)
     model.eval()
+    if args.color:
+        model = ColorNet(model)
     print(model)
     for param in model.parameters():
         param.requires_grad = False
@@ -54,6 +57,7 @@ def main(args, model):
                                          'sliding_window_t': args.sliding_window_t}
                       }
     if not args.legacy:
+        print("Updated style model")
         dataset_kwargs['transforms'] = {'RobustNorm': {}}
         dataset_kwargs['combined_voxel_channels'] = False

@@ -79,7 +83,8 @@ def main(args, model):
     model.reset_states()
     for i, item in enumerate(tqdm(data_loader)):
         voxel = item['events'].to(device)
-        voxel = crop.pad(voxel)
+        if not args.color:
+            voxel = crop.pad(voxel)
         with CudaTimer('Inference'):
             output = model(voxel)
         # save sample images, or do something with output here
@@ -97,11 +102,14 @@ def main(args, model):
                 fname = 'flow_{:010d}.png'.format(i)
                 cv2.imwrite(os.path.join(args.output_folder, fname), flow_img)
             else:
-                image = crop.crop(output['image'])
-                image = torch.squeeze(image)  # H x W
-                image = image.cpu().numpy()  # normalize here
-                image = np.clip(image, 0, 1)  # normalize here
-                image = (image * 255).astype(np.uint8)
+                if args.color:
+                    image = output['image']
+                else:
+                    image = crop.crop(output['image'])
+                    image = torch.squeeze(image)  # H x W
+                    image = image.cpu().numpy()  # normalize here
+                    image = np.clip(image, 0, 1)  # normalize here
+                    image = (image * 255).astype(np.uint8)
                 fname = 'frame_{:010d}.png'.format(i)
                 cv2.imwrite(join(args.output_folder, fname), image)
                 ts_file.write('{:.15f}\n'.format(item['timestamp'].item()))
@@ -128,6 +136,8 @@ def main(args, model):
                         help='If true, save output to flow npy file')
     parser.add_argument('--legacy', action='store_true',
                         help='Set this if using any of the original networks from ECCV20 paper')
+    parser.add_argument('--color', action='store_true', default=False,
+                        help='Perform color reconstruction')
     parser.add_argument('--voxel_method', default='between_frames', type=str,
                         help='which method should be used to form the voxels',
                         choices=['between_frames', 'k_events', 't_seconds'])
@@ -149,27 +159,5 @@ def main(args, model):
     checkpoint = torch.load(args.checkpoint_path)
     kwargs['checkpoint'] = checkpoint

-    # import h5py
-    # dataset_kwargs = {'transforms': {},
-    #                   'max_length': None,
-    #                   'sensor_size': None,
-    #                   'num_bins': 5,
-    #                   'legacy': True,
-    #                   'voxel_method': {'method':'between_frames'}
-    #                   }
-    # h5_path = "/home/timo/Data2/preprocessed_datasets/h5_voxels/slider_depth_cut.h5"
-    # data_loader = InferenceDataLoader(args.h5_file_path, dataset_kwargs=dataset_kwargs)
-    # h5_file = h5py.File(h5_path, 'r')
-    # for i, item in enumerate(data_loader):
-    #     data_name = "frame_{:09d}".format(i)
-    #     dset = h5_file[data_name]
-    #     voxel = np.stack([bin[:] for bin in dset['voxels'].values()], axis=0)  # C x H x W
-    #     new_voxel = item['events']
-    #
-    #     if True:
-    #         print(np.sum(voxel))
-    #         print(torch.sum(new_voxel))
-    #         #print(voxel)
-    #         #print(new_voxel)
     model = load_model(**kwargs)
     main(args, model)
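With the new flag, a run looks like: python inference.py --checkpoint_path model.pth --h5_file_path events.h5 --output_folder out --color (paths hypothetical). Note that the --color path skips the top-level crop.pad call because ColorNet pads each half-resolution color channel itself with its own CropParameters. The padding exists so H and W are divisible by 2**num_encoders before entering the UNet-style encoders; a sketch of the idea, assuming simple right/bottom zero padding (CropParameters' exact placement lives in utils.util):

import math
import torch.nn.functional as F

def pad_to_multiple(x, num_encoders):
    # Pad H and W up to the nearest multiple of 2**num_encoders so that
    # repeated stride-2 downsampling and upsampling stay size-consistent.
    factor = 2 ** num_encoders
    h, w = x.shape[-2:]
    pad_h = math.ceil(h / factor) * factor - h
    pad_w = math.ceil(w / factor) * factor - w
    return F.pad(x, (0, pad_w, 0, pad_h))  # pad spec: (left, right, top, bottom)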
54 changes: 54 additions & 0 deletions model/model.py
@@ -9,6 +9,7 @@

 from .unet import UNetFlow, WNet, UNetFlowNoRecur, UNetRecurrent, UNet
 from .submodules import ResidualBlock, ConvGRU, ConvLayer
+from utils.color_utils import merge_channels_into_color_image


 def copy_states(states):
@@ -21,6 +22,59 @@ def copy_states(states):
     return recursive_clone(states)


+class ColorNet(BaseModel):
+    """
+    Split the input events into RGBW channels and feed them to an existing
+    recurrent model with states.
+    """
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+        self.channels = {'R': [slice(0, None, 2), slice(0, None, 2)],
+                         'G': [slice(0, None, 2), slice(1, None, 2)],
+                         'B': [slice(1, None, 2), slice(1, None, 2)],
+                         'W': [slice(1, None, 2), slice(0, None, 2)],
+                         'grayscale': [slice(None), slice(None)]}
+        self.prev_states = {k: self.model.states for k in self.channels}
+
+    def reset_states(self):
+        self.model.reset_states()
+
+    @property
+    def num_encoders(self):
+        return self.model.num_encoders
+
+    def forward(self, event_tensor):
+        """
+        :param event_tensor: N x num_bins x H x W
+        :return: output dict with RGB image taking values in [0, 1], and
+                 displacement within event_tensor.
+        """
+        height, width = event_tensor.shape[-2:]
+        crop_halfres = CropParameters(int(width / 2), int(height / 2), self.model.num_encoders)
+        crop_fullres = CropParameters(width, height, self.model.num_encoders)
+        color_events = {}
+        reconstructions_for_each_channel = {}
+        for channel, s in self.channels.items():
+            color_events = event_tensor[:, :, s[0], s[1]]
+            if channel == 'grayscale':
+                color_events = crop_fullres.pad(color_events)
+            else:
+                color_events = crop_halfres.pad(color_events)
+            self.model.states = self.prev_states[channel]
+            img = self.model(color_events)['image']
+            self.prev_states[channel] = self.model.states
+            if channel == 'grayscale':
+                img = crop_fullres.crop(img)
+            else:
+                img = crop_halfres.crop(img)
+            img = img[0, 0, ...].cpu().numpy()
+            img = np.clip(img * 255, 0, 255).astype(np.uint8)
+            reconstructions_for_each_channel[channel] = img
+        image_bgr = merge_channels_into_color_image(reconstructions_for_each_channel)  # H x W x 3
+        return {'image': image_bgr}
+
+
 class WFlowNet(BaseModel):
     """
     Recurrent, UNet-like architecture where each encoder is followed by a ConvLSTM or ConvGRU.
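ColorNet's slice table implements a color-filter-array split: it assumes the events come from a color event camera whose pixels alternate R/G/W/B in 2x2 blocks (a layout matching the Color-DAVIS346 used for the CED color event dataset — an assumption, since the commit does not name the sensor). Each channel sees a half-resolution event stream, and per-channel recurrent states are kept in prev_states so each color stream carries its own temporal memory across frames. A small numpy illustration of the slicing:

import numpy as np

events = np.arange(16).reshape(1, 1, 4, 4)  # N x num_bins x H x W, toy values

channels = {'R': (slice(0, None, 2), slice(0, None, 2)),   # even rows, even cols
            'G': (slice(0, None, 2), slice(1, None, 2)),   # even rows, odd cols
            'B': (slice(1, None, 2), slice(1, None, 2)),   # odd rows, odd cols
            'W': (slice(1, None, 2), slice(0, None, 2))}   # odd rows, even cols

for name, (rows, cols) in channels.items():
    sub = events[:, :, rows, cols]       # half-resolution sub-stream
    print(name, sub.shape, sub.ravel())  # each is N x num_bins x 2 x 2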
92 changes: 92 additions & 0 deletions utils/color_utils.py
@@ -0,0 +1,92 @@
from .timers import Timer, CudaTimer
import numpy as np
import cv2


def shift_image(X, dx, dy):
    X = np.roll(X, dy, axis=0)
    X = np.roll(X, dx, axis=1)
    if dy > 0:
        X[:dy, :] = np.expand_dims(X[dy, :], axis=0)
    elif dy < 0:
        X[dy:, :] = np.expand_dims(X[dy, :], axis=0)
    if dx > 0:
        X[:, :dx] = np.expand_dims(X[:, dx], axis=1)
    elif dx < 0:
        X[:, dx:] = np.expand_dims(X[:, dx], axis=1)
    return X


def upsample_color_image(grayscale_highres, color_lowres_bgr, colorspace='LAB'):
    """
    Generate a high res color image from a high res grayscale image, and a low res color image,
    using the trick described in:
    http://www.planetary.org/blogs/emily-lakdawalla/2013/04231204-image-processing-colorizing-images.html
    """
    assert(len(grayscale_highres.shape) == 2)
    assert(len(color_lowres_bgr.shape) == 3 and color_lowres_bgr.shape[2] == 3)

    if colorspace == 'LAB':
        # convert color image to LAB space
        lab = cv2.cvtColor(src=color_lowres_bgr, code=cv2.COLOR_BGR2LAB)
        # replace lightness channel with the highres image
        lab[:, :, 0] = grayscale_highres
        # convert back to BGR
        color_highres_bgr = cv2.cvtColor(src=lab, code=cv2.COLOR_LAB2BGR)
    elif colorspace == 'HSV':
        # convert color image to HSV space
        hsv = cv2.cvtColor(src=color_lowres_bgr, code=cv2.COLOR_BGR2HSV)
        # replace value channel with the highres image
        hsv[:, :, 2] = grayscale_highres
        # convert back to BGR
        color_highres_bgr = cv2.cvtColor(src=hsv, code=cv2.COLOR_HSV2BGR)
    elif colorspace == 'HLS':
        # convert color image to HLS space
        hls = cv2.cvtColor(src=color_lowres_bgr, code=cv2.COLOR_BGR2HLS)
        # replace lightness channel with the highres image
        hls[:, :, 1] = grayscale_highres
        # convert back to BGR
        color_highres_bgr = cv2.cvtColor(src=hls, code=cv2.COLOR_HLS2BGR)

    return color_highres_bgr


def merge_channels_into_color_image(channels):
    """
    Combine a full resolution grayscale reconstruction and four color channels at half resolution
    into a color image at full resolution.
    :param channels: dictionary containing the four color reconstructions (at quarter resolution),
                     and the full resolution grayscale reconstruction.
    :return a color image at full resolution
    """

    with Timer('Merge color channels'):

        assert('R' in channels)
        assert('G' in channels)
        assert('W' in channels)
        assert('B' in channels)
        assert('grayscale' in channels)

        # upsample each channel independently
        for channel in ['R', 'G', 'W', 'B']:
            channels[channel] = cv2.resize(channels[channel], dsize=None, fx=2, fy=2, interpolation=cv2.INTER_LINEAR)

        # Shift the channels so that they all have the same origin
        channels['B'] = shift_image(channels['B'], dx=1, dy=1)
        channels['G'] = shift_image(channels['G'], dx=1, dy=0)
        channels['W'] = shift_image(channels['W'], dx=0, dy=1)

        # reconstruct the color image at half the resolution using the reconstructed channels RGBW
        reconstruction_bgr = np.dstack([channels['B'],
                                        cv2.addWeighted(src1=channels['G'], alpha=0.5,
                                                        src2=channels['W'], beta=0.5,
                                                        gamma=0.0, dtype=cv2.CV_8U),
                                        channels['R']])

        reconstruction_grayscale = channels['grayscale']

        # combine the full res grayscale resolution with the low res to get a full res color image
        upsampled_img = upsample_color_image(reconstruction_grayscale, reconstruction_bgr)
        return upsampled_img
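A hypothetical usage sketch, with random stand-ins for the five reconstructions (real inputs come from ColorNet above); the grayscale image must be exactly twice the color channels' resolution so the LAB lightness replacement lines up:

import numpy as np
from utils.color_utils import merge_channels_into_color_image

channels = {c: np.random.randint(0, 256, (32, 32), dtype=np.uint8)
            for c in ['R', 'G', 'B', 'W']}  # half-resolution color channels
channels['grayscale'] = np.random.randint(0, 256, (64, 64), dtype=np.uint8)

bgr = merge_channels_into_color_image(channels)
print(bgr.shape, bgr.dtype)  # (64, 64, 3) uint8, ready for cv2.imwrite

Note the G and W channels are averaged into the green plane before the LAB upsampling, a cheap demosaicing choice that treats white pixels as extra green-weighted luminance samples.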
