
Commit

minor
TimoStoff committed Jul 7, 2020
1 parent c06add3 commit 56c0f97
Showing 1 changed file with 8 additions and 31 deletions.
39 changes: 8 additions & 31 deletions inference.py
@@ -5,7 +5,6 @@
from os.path import join
import os
import cv2
from thop import profile
from tqdm import tqdm

from utils.util import ensure_dir, flow2bgr_np
@@ -21,12 +20,15 @@
model_info = {}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_height_width(data_loader):
for d in data_loader:
return d['events'].shape[-2:] # d['events'] is a ... x H x W voxel grid

def torch2cv2(image):
"""convert torch tensor to format compatible with cv2.imwrite"""
image = torch.squeeze(image) # H x W
image = image.cpu().numpy() # normalize here
image = np.clip(image, 0, 1) # normalize here
image = image.cpu().numpy()
image = np.clip(image, 0, 1)
return (image * 255).astype(np.uint8)

def setup_output_folder(output_folder):
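
Not part of the commit: a minimal usage sketch for the two helpers shown above, assuming a data_loader like the one built in main() and a random tensor standing in for a reconstructed frame.

    # Hypothetical sketch (not from this diff): data_loader items carry an
    # 'events' tensor shaped ... x H x W; `pred` stands in for a network output.
    height, width = get_height_width(data_loader)
    pred = torch.rand(1, 1, height, width)        # values already in [0, 1]
    frame = torch2cv2(pred)                       # squeeze, clip, scale to uint8
    cv2.imwrite('frame_0000000000.png', frame)    # same naming pattern as the main loop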
@@ -91,23 +93,12 @@ def main(args, model):

data_loader = InferenceDataLoader(args.events_file_path, dataset_kwargs=dataset_kwargs, ltype=args.loader_type)

height, width = None, None
for d in data_loader:
height, width = d['events'].shape[-2:]
break
if height is None or width is None:
raise Exception("Could not determine width+height")
height, width = get_height_width(data_loader)

model_info['input_shape'] = height, width
crop = CropParameters(width, height, model.num_encoders)

# count FLOPs
tmp_voxel = crop.pad(torch.randn(1, model_info['num_bins'], height, width).to(device))
model_info['FLOPs'], model_info['Params'] = profile(model, inputs=(tmp_voxel, ))

ts_fname = setup_output_folder(args.output_folder)
if args.output_folder_gt:
ts_fname_gt = setup_output_folder(args.output_folder_gt)

model.reset_states()
for i, item in enumerate(tqdm(data_loader)):
@@ -123,7 +114,7 @@
if item['dt'] == 0:
flow = flow_t.cpu().numpy()
else:
flow = flow_t.cpu().numpy()/item['dt'].numpy()
flow = flow_t.cpu().numpy() / item['dt'].numpy()
ts = item['timestamp'].cpu().numpy()
flow_dict = flow
fname = 'flow_{:010d}.npy'.format(i)
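
Not part of the commit: the flow files written above can be sanity-checked with plain NumPy afterwards. A minimal sketch, assuming the saved array is shaped (1, 2, H, W) with horizontal and vertical displacement channels (already divided by dt when dt != 0, as in the hunk above).

    import numpy as np

    # Sketch only; the (1, 2, H, W) layout and channel order are assumptions.
    flow = np.load('flow_0000000000.npy')
    u, v = flow[0, 0], flow[0, 1]            # horizontal / vertical displacement
    magnitude = np.sqrt(u ** 2 + v ** 2)     # per-pixel flow magnitude
    print('mean flow magnitude: {:.3f}'.format(float(magnitude.mean())))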
@@ -142,21 +133,9 @@
image = torch2cv2(image)
fname = 'frame_{:010d}.png'.format(i)
cv2.imwrite(join(args.output_folder, fname), image)
if args.output_folder_gt:
image_gt = item['frame']
image_gt = torch2cv2(image_gt)
cv2.imwrite(join(args.output_folder_gt, fname), image_gt)
append_timestamp(ts_fname_gt, fname, item['timestamp'].item())
append_timestamp(ts_fname, fname, item['timestamp'].item())


# def print_model_info():
# print('Input shape: {} x {} x {}'.format(model_info.pop('num_bins'), *model_info.pop('input_shape')))
# print('== Model statistics ==')
# for k, v in model_info.items():
# print('{}: {:.2f} {}'.format(k, *format_power(v)))


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='PyTorch Template')
parser.add_argument('--checkpoint_path', required=True, type=str,
@@ -165,8 +144,6 @@ def main(args, model):
help='path to events (HDF5)')
parser.add_argument('--output_folder', default="/tmp/output", type=str,
help='where to save outputs to')
parser.add_argument('--output_folder_gt', default='', type=str,
help='where to save groundtruth to')
parser.add_argument('--device', default='0', type=str,
help='indices of GPUs to enable')
parser.add_argument('--is_flow', action='store_true',
@@ -199,9 +176,9 @@ def main(args, model):

if args.device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = args.device
kwargs = {}
print('Loading checkpoint: {} ...'.format(args.checkpoint_path))
checkpoint = torch.load(args.checkpoint_path)
assert not (args.e2vid and args.firenet_legacy)
if args.e2vid:
args.legacy_norm = True
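
Not part of the commit: with --output_folder_gt removed, the remaining argument set can be exercised by handing parse_args an explicit list; every path below is a placeholder.

    # Hypothetical invocation sketch; checkpoint and event-file paths are placeholders.
    args = parser.parse_args([
        '--checkpoint_path', 'pretrained/model.pth',
        '--events_file_path', 'data/events.h5',
        '--output_folder', '/tmp/output',
        '--device', '0',
    ])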
