3
3
4
4
from . import ddn , ddn_loss
5
5
from pcdet .models .model_utils .basic_block_2d import BasicBlock2D
6
-
6
+ from skimage import transform
7
+ import numpy as np
8
+ import torch
9
+ import matplotlib .pyplot as plt
7
10
8
11
class DepthFFN (nn .Module ):
9
12
@@ -30,18 +33,51 @@ def __init__(self, model_cfg, downsample_factor):
30
33
self .channel_reduce = BasicBlock2D (** model_cfg .CHANNEL_REDUCE )
31
34
32
35
# DDN_LOSS is optional
33
- if model_cfg .get ('LOSS ' ,None ) is not None :
34
- self .ddn_loss = ddn_loss .__all__ [model_cfg .LOSS .NAME ](
36
+ if model_cfg .get ('LOSS_ ' ,None ) is not None :
37
+ self .ddn_loss = ddn_loss .__all__ [model_cfg .LOSS_ .NAME ](
35
38
disc_cfg = self .disc_cfg ,
36
39
downsample_factor = downsample_factor ,
37
- ** model_cfg .LOSS .ARGS
40
+ ** model_cfg .LOSS_ .ARGS
38
41
)
42
+
39
43
else :
40
44
self .ddn_loss = None
41
45
self .forward_ret_dict = {}
42
46
43
47
def get_output_feature_dim (self ):
44
48
return self .channel_reduce .out_channels
49
+ # sparse average pooling, by hoiliu
50
+ def sparse_avg_pooling (self , feature_map , size = 2 ):
51
+ feature_map = feature_map .cpu ().detach ()
52
+ device = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
53
+ Batch = feature_map .shape [0 ]
54
+ pool_out_list = []
55
+ for i in range (Batch ):
56
+ a = transform .downscale_local_mean (feature_map [i ,:,:],(size ,size ))
57
+ b = transform .downscale_local_mean (feature_map [i ,:,:]!= 0 ,(size ,size ))
58
+ pool_out = a / (b + 1e-10 )
59
+ pool_out_list .append (torch .from_numpy (pool_out ).float ().to (device ))
60
+ pool_out1 = torch .stack (pool_out_list )
61
+ return pool_out1
62
+ def create_depth_target (self , depth_map_target , depth_target_bin ):
63
+ device = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
64
+ B , h , w = depth_map_target .shape # torch.Size([2, 47, 156])
65
+ D = 120
66
+ depth_target = torch .from_numpy (np .zeros ([B , D + 1 , h , w ])).float ().to (device )
67
+ for b in range (B ):
68
+ for i in range (h ):
69
+ for j in range (w ):
70
+ bin_value = depth_target_bin [b ,i ,j ]
71
+ depth_target [b ,bin_value ,i ,j ] = 1
72
+ # if bin_value==120: # out of boundary
73
+ # # print("depth_map_target:", depth_map_target[b,i,j])
74
+ # depth_target [b,bin_value,i,j] = 100000
75
+ # elif bin_value>120 or bin_value<0:
76
+ # print("error bin")
77
+ # else:
78
+ # depth_target [b,bin_value,i,j] = 1
79
+
80
+ return depth_target
45
81
46
82
def forward (self , batch_dict ):
47
83
"""
@@ -53,29 +89,45 @@ def forward(self, batch_dict):
53
89
batch_dict:
54
90
frustum_features: (N, C, D, H_out, W_out), Image depth features
55
91
"""
92
+ # print("batch_dict['frame_id']:", batch_dict['frame_id'][0])
56
93
# Pixel-wise depth classification
57
- images = batch_dict ["images" ]
58
- ddn_result = self .ddn (images )
59
- image_features = ddn_result ["features" ]
60
- depth_logits = ddn_result ["logits" ]
61
-
94
+ images = batch_dict ["images" ] #([2, 3, 375, 1242])
95
+ ddn_result = self .ddn (images ) # self.ddn is a pretrained backbone, which is used to generate pretrained depth feature and depth bin
96
+ image_features = ddn_result ["features" ] #([2, 1024, 47, 156])
97
+ depth_logits = ddn_result ["logits" ] #([2, 121, 47, 156])
62
98
# Channel reduce
63
99
if self .channel_reduce is not None :
64
- image_features = self .channel_reduce (image_features )
100
+ image_features = self .channel_reduce (image_features ) # 1024 -> 64
65
101
66
102
# Create image feature plane-sweep volume
67
103
frustum_features = self .create_frustum_features (image_features = image_features ,
68
104
depth_logits = depth_logits )
69
105
batch_dict ["frustum_features" ] = frustum_features
70
-
71
- if self .training :
72
- # depth_maps and gt_boxes2d are optional
73
- self .forward_ret_dict ["depth_maps" ] = batch_dict .get ("depth_maps" ,None )
74
- self .forward_ret_dict ["gt_boxes2d" ] = batch_dict .get ("gt_boxes2d" ,None )
75
- self .forward_ret_dict ["depth_logits" ] = depth_logits
106
+ batch_dict ["image_features" ] = image_features
107
+
108
+ # depth_maps and gt_boxes2d are optional
109
+ self .forward_ret_dict ["depth_maps" ] = batch_dict .get ("depth_maps" ,None )
110
+ self .forward_ret_dict ["gt_boxes2d" ] = batch_dict .get ("gt_boxes2d" ,None )
111
+ self .forward_ret_dict ["depth_logits" ] = depth_logits # torch.Size([2, 121, 47, 156])
112
+ #### New code ####
113
+ #### Create Lidar-image-lije feature plane-sweep volume ###
114
+ self .forward_ret_dict ["depth_maps" ] = self .sparse_avg_pooling (self .forward_ret_dict ["depth_maps" ], 8 )
115
+ # save_path="/home/ipl-pc/cmkd/output/vis_result"+".depth.png"
116
+ # print(self.forward_ret_dict["depth_maps"])
117
+ # exit()
118
+ # plt.imsave(save_path, self.forward_ret_dict["depth_maps"][0,:].cpu().detach())
119
+
120
+ depth_map_target = self .forward_ret_dict ["depth_maps" ] ## ([47, 156])
121
+ depth_target_bin = self .ddn_loss (** self .forward_ret_dict ) # 0-120, total 121 dim
122
+ # print(np.unique(depth_target_bin[0,:].cpu().detach())) # 29-120
123
+ depth_target = self .create_depth_target (depth_map_target , depth_target_bin )
124
+ frustum_features_target = self .create_frustum_features (image_features , depth_target , target = True )
125
+ batch_dict ["frustum_features_target" ] = frustum_features_target
126
+ frustum_features = self .create_frustum_features (image_features = image_features ,
127
+ depth_logits = depth_logits )
76
128
return batch_dict
77
129
78
- def create_frustum_features (self , image_features , depth_logits ):
130
+ def create_frustum_features (self , image_features , depth_logits , target = False ):
79
131
"""
80
132
Create image depth feature volume by multiplying image features with depth distributions
81
133
Args:
@@ -88,15 +140,19 @@ def create_frustum_features(self, image_features, depth_logits):
88
140
depth_dim = 2
89
141
90
142
# Resize to match dimensions
91
- image_features = image_features .unsqueeze (depth_dim )
92
- depth_logits = depth_logits .unsqueeze (channel_dim )
93
-
143
+ image_features = image_features .unsqueeze (depth_dim ) # [2, 64, 47, 156] -> [2, 64, 1, 47, 156]
144
+ depth_logits = depth_logits .unsqueeze (channel_dim ) # [2, 120, 47, 156] -> [2, 1, 120, 47, 156]
94
145
# Apply softmax along depth axis and remove last depth category (> Max Range)
95
- depth_probs = F .softmax (depth_logits , dim = depth_dim )
96
- depth_probs = depth_probs [:, :, :- 1 ]
97
-
146
+ # print("depth_logits:", depth_logits[:,:,-1,:,:])
147
+ if target :
148
+ depth_probs = depth_logits [:, :, :- 1 ] # [2, 1, 120, 47, 156]
149
+ # print("depth_probs:", depth_probs)
150
+ else :
151
+ depth_probs = F .softmax (depth_logits , dim = depth_dim ) # [2, 1, 121, 47, 156]
152
+ depth_probs = depth_probs [:, :, :- 1 ] # [2, 1, 120, 47, 156]
153
+ # print("depth_probs__:", depth_probs)
98
154
# Multiply to form image depth feature volume
99
- frustum_features = depth_probs * image_features
155
+ frustum_features = depth_probs * image_features # [2, 64, 120, 47, 156] = [2, 1, 120, 47, 156] * [2, 64, 1, 47, 156]
100
156
return frustum_features
101
157
102
158
def get_loss (self ):
@@ -108,4 +164,4 @@ def get_loss(self):
108
164
tb_dict: dict[float], All losses to log in tensorboard
109
165
"""
110
166
loss , tb_dict = self .ddn_loss (** self .forward_ret_dict )
111
- return loss , tb_dict
167
+ return loss , tb_dict
0 commit comments