# Authors: Zhaoshuo Li, Xingtong Liu, Francis X. Creighton, Russell H. Taylor, and Mathias Unberath
#
# Copyright (c) 2020. Johns Hopkins University - All rights reserved.
import copy
import numpy as np
import torch
import torch.nn as nn

class NestedTensor(object):
    """Container for a stereo pair plus optional ground-truth disparity,
    occlusion masks, and sampled row/column indices."""

    def __init__(self, left, right, disp=None, sampled_cols=None, sampled_rows=None, occ_mask=None,
                 occ_mask_right=None):
        self.left = left
        self.right = right
        self.disp = disp
        self.occ_mask = occ_mask
        self.occ_mask_right = occ_mask_right
        self.sampled_cols = sampled_cols
        self.sampled_rows = sampled_rows
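
# Example usage (illustrative sketch, not part of the original file; the image
# shapes are assumptions):
#
#   left = torch.rand(1, 3, 256, 512)   # NxCxHxW left image
#   right = torch.rand(1, 3, 256, 512)  # NxCxHxW right image
#   inputs = NestedTensor(left, right)  # disparity and occlusion masks stay None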

def center_crop(layer, max_height, max_width):
    """Crop the central max_height x max_width window from an NxCxHxW tensor."""
    _, _, h, w = layer.size()
    xy1 = (w - max_width) // 2
    xy2 = (h - max_height) // 2
    return layer[:, :, xy2:(xy2 + max_height), xy1:(xy1 + max_width)]
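
# Example usage (illustrative sketch, not part of the original file): cropping
# a 1x1x10x10 map to its central 6x8 window.
#
#   x = torch.arange(100.).view(1, 1, 10, 10)
#   center_crop(x, 6, 8).shape  # torch.Size([1, 1, 6, 8])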

def batched_index_select(source, dim, index):
    """Gather entries of source along dim using a per-batch index of shape [N, K]."""
    views = [source.shape[0]] + [1 if i != dim else -1 for i in range(1, len(source.shape))]
    expanse = list(source.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.view(views).expand(expanse)  # broadcast index to match source
    return torch.gather(source, dim, index)
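
# Example usage (illustrative sketch, not part of the original file): keeping a
# per-batch subset of columns from an NxCxHxW tensor with dim=3.
#
#   source = torch.rand(2, 8, 4, 16)
#   cols = torch.tensor([[0, 5, 10], [1, 6, 11]])  # [N, K] column ids per batch
#   out = batched_index_select(source, 3, cols)    # shape [2, 8, 4, 3]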

def torch_1d_sample(source, sample_points, mode='linear'):
    """
    Sample the source tensor along its last dimension at (possibly fractional) locations.

    input:
        source [N,D1,D2,D3...,Dn]
        sample_points [N,D1,D2,....,Dn-1,1]
    output:
        [N,D1,D2...,Dn-1]
    """
    idx_l = torch.floor(sample_points).long().clamp(0, source.size(-1) - 1)
    idx_r = torch.ceil(sample_points).long().clamp(0, source.size(-1) - 1)
    if mode == 'linear':
        weight_r = sample_points - idx_l
        weight_l = 1 - weight_r
    elif mode == 'sum':
        weight_r = (idx_r != idx_l).int()  # only add the right neighbor at non-integer locations
        weight_l = 1
    else:
        raise ValueError('mode not recognized')
    out = torch.gather(source, -1, idx_l) * weight_l + torch.gather(source, -1, idx_r) * weight_r
    return out.squeeze(-1)
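
# Example usage (illustrative sketch, not part of the original file): sampling
# at a fractional position with linear interpolation.
#
#   src = torch.tensor([[0., 10., 20., 30.]])  # [N, D] = [1, 4]
#   pts = torch.tensor([[1.5]])                # [N, 1]
#   torch_1d_sample(src, pts)                  # tensor([15.])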

def get_clones(module, N):
    """Return an nn.ModuleList of N independent deep copies of module."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
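
# Example usage (illustrative sketch, not part of the original file): stacking
# six independent copies of a layer, as in a Transformer encoder.
#
#   layer = nn.Linear(8, 8)
#   layers = get_clones(layer, 6)  # nn.ModuleList of 6 deep copies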

def find_occ_mask(disp_left, disp_right):
    """
    Compute occlusion masks from a left/right disparity pair.
    1 indicates occlusion; disparities are assumed to lie in [0, w].
    """
    w = disp_left.shape[-1]

    # left occlusion
    # find corresponding pixels in the right image
    coord = np.linspace(0, w - 1, w)[None,]  # 1xW
    right_shifted = coord - disp_left
    # 1. locations shifted past the left border are occluded
    occ_mask_l = right_shifted <= 0
    # 2. inconsistent matches are occluded
    right_shifted[occ_mask_l] = 0  # reset out-of-range locations to 0
    right_shifted = right_shifted.astype(int)  # np.int is removed in NumPy >= 1.24
    disp_right_selected = np.take_along_axis(disp_right, right_shifted,
                                             axis=1)  # right disparity at the shifted locations
    wrong_matches = np.abs(disp_right_selected - disp_left) > 1  # theoretically, these two should match perfectly
    wrong_matches[disp_right_selected <= 0.0] = False
    wrong_matches[disp_left <= 0.0] = False
    # produce final occlusion mask
    wrong_matches[occ_mask_l] = True  # apply case 1 occlusion on top of case 2
    occ_mask_l = wrong_matches

    # right occlusion
    # find corresponding pixels in the left image
    coord = np.linspace(0, w - 1, w)[None,]  # 1xW
    left_shifted = coord + disp_right
    # 1. locations shifted past the right border are occluded
    occ_mask_r = left_shifted >= w
    # 2. inconsistent matches are occluded
    left_shifted[occ_mask_r] = 0  # reset out-of-range locations to 0
    left_shifted = left_shifted.astype(int)  # np.int is removed in NumPy >= 1.24
    disp_left_selected = np.take_along_axis(disp_left, left_shifted,
                                            axis=1)  # left disparity at the shifted locations
    wrong_matches = np.abs(disp_left_selected - disp_right) > 1  # theoretically, these two should match perfectly
    wrong_matches[disp_left_selected <= 0.0] = False
    wrong_matches[disp_right <= 0.0] = False
    # produce final occlusion mask
    wrong_matches[occ_mask_r] = True  # apply case 1 occlusion on top of case 2
    occ_mask_r = wrong_matches

    return occ_mask_l, occ_mask_r
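
# Example usage (illustrative sketch, not part of the original file; the
# disparity values are synthetic):
#
#   disp_l = np.random.uniform(0.0, 4.0, (16, 32))
#   disp_r = np.random.uniform(0.0, 4.0, (16, 32))
#   occ_l, occ_r = find_occ_mask(disp_l, disp_r)  # boolean HxW masks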

def save_and_clear(idx, output_file):
    """Dump the accumulated outputs to output-<idx>.dat, empty the buffers, and return idx + 1."""
    with open(f'output-{idx}.dat', 'wb') as f:
        torch.save(output_file, f)
    idx += 1
    # clear buffered outputs
    for key in output_file:
        output_file[key].clear()
    return idx
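
# Example usage (illustrative sketch, not part of the original file; the buffer
# keys are assumptions):
#
#   buffers = {'disp': [torch.rand(2, 4)], 'occ': [torch.rand(2, 4)]}
#   idx = save_and_clear(0, buffers)  # writes output-0.dat, empties the lists, returns 1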