
Commit 829cf26

Committed Oct 17, 2023
saving edits on depth_ffn
1 parent 000cc6d · commit 829cf26

6 files changed: +209 -205 lines changed

 

.vscode/launch.json

+4
@@ -19,6 +19,10 @@
             ]
             //CUDA_VISIBLE_DEVICES=1 python tools/train_cmkd.py --cfg cfgs/kitti_models/CMKD/CMKD-scd/cmkd_kitti_R50_scd_V2_lpcg.yaml --pretrained_lidar_model /home/ipl-pc/cmkd/checkpoints/scd-teacher-kitti.pth --pretrained_img_model /home/ipl-pc/cmkd/checkpoints/cmkd-scd-2161.pth
             //CUDA_VISIBLE_DEVICES=1 python train_cmkd.py --cfg /home/ipl-pc/cmkd/tools/cfgs/kitti_models/CMKD/CMKD-scd/cmkd_kitti_R50_scd_V2_lpcg.yaml --pretrained_lidar_model /home/ipl-pc/cmkd/checkpoints/scd-teacher-kitti.pth --pretrained_img_model /home/ipl-pc/cmkd/checkpoints/cmkd-scd-2161.pth
+
+            //CUDA_VISIBLE_DEVICES=0 python test_cmkd.py --cfg ./cfgs/kitti_models/CMKD/CMKD-scd/cmkd_kitti_R50_scd_V2_lpcg.yaml --ckpt /home/ipl-pc/cmkd/output/home/ipl-pc/cmkd/tools/cfgs/kitti_models/CMKD/CMKD-scd/cmkd_kitti_R50_scd_V2_lpcg/default/ckpt/checkpoint_epoch_9.pth
+
+
         },
         {
             "name": "Test Kitti",

pcdet/models/backbones_2d/map_to_bev/conv2d_collapse.py

+9 -9
@@ -55,16 +55,16 @@ def forward(self, batch_dict):
         batch_dict["spatial_features"] = bev_features_ori
 
         ## Disentagle bev-image into two copies ###
-        # bev_features_new = self.blck_copy(bev_features)
-        # bev_features_new = self.sam(bev_features_new)
-        # batch_dict["spatial_features_copy"] = bev_features_new
+        bev_features_new = self.blck_copy(bev_features)
+        bev_features_new = self.sam(bev_features_new)
+        batch_dict["spatial_features_copy"] = bev_features_new
 
-        # # #### Image like bev ####
-        # voxel_features_target = batch_dict["voxel_features_target"]
-        # bev_features_target = voxel_features_target.flatten(start_dim=1, end_dim=2) # (B, C, Z, Y, X) -> (B, C*Z, Y, X)
-        # bev_features_target = self.block_target(bev_features_target)
-        # bev_features_target = self.GC_block(bev_features_target) # (B, C*Z, Y, X) -> (B, C, Y, X)
-        # batch_dict["spatial_features_target"] = bev_features_target
+        # #### Image like bev ####
+        voxel_features_target = batch_dict["voxel_features_target"]
+        bev_features_target = voxel_features_target.flatten(start_dim=1, end_dim=2) # (B, C, Z, Y, X) -> (B, C*Z, Y, X)
+        bev_features_target = self.block_target(bev_features_target)
+        bev_features_target = self.GC_block(bev_features_target) # (B, C*Z, Y, X) -> (B, C, Y, X)
+        batch_dict["spatial_features_target"] = bev_features_target
         # # # #### Fusion ####
         # batch_dict["spatial_features_fusion"] = batch_dict["spatial_features_copy"] + 0.2 *batch_dict["spatial_features"]
         return batch_dict
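Note: the newly enabled block builds a second, LiDAR-like BEV map ("spatial_features_target") by folding the height axis of the target voxel volume into the channel axis and collapsing it back to the BEV channel count. The sketch below illustrates only that collapse step under assumed shapes; the 1x1-conv block is a stand-in for GC_block, whose exact definition is not shown in this diff.

import torch
import torch.nn as nn

# Assumed shapes for illustration: B=2, C=32 channels, Z=8 height bins, Y=200, X=176 BEV cells.
B, C, Z, Y, X = 2, 32, 8, 200, 176
voxel_features_target = torch.randn(B, C, Z, Y, X)

# (B, C, Z, Y, X) -> (B, C*Z, Y, X): fold the height axis into the channel axis.
bev_flat = voxel_features_target.flatten(start_dim=1, end_dim=2)

# Stand-in for GC_block: a 1x1 convolution that maps C*Z channels back to C BEV channels.
collapse = nn.Sequential(nn.Conv2d(C * Z, C, kernel_size=1), nn.BatchNorm2d(C), nn.ReLU())
bev_features_target = collapse(bev_flat)   # (B, C, Y, X)
print(bev_features_target.shape)           # torch.Size([2, 32, 200, 176])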

pcdet/models/backbones_3d/vfe/image_vfe_modules/f2v/frustum_to_voxel.py

+4 -4
@@ -54,8 +54,8 @@ def forward(self, batch_dict):
 
         ##### lidar frumstum #####
         # Sample frustum volume to generate voxel volume
-        # voxel_features_target = self.sampler(input_features=batch_dict["frustum_features_target"],
-        #                                      grid=grid) # (B, C, X, Y, Z)
-        # voxel_features_target = voxel_features_target.permute(0, 1, 4, 3, 2)
-        # batch_dict["voxel_features_target"] = voxel_features_target
+        voxel_features_target = self.sampler(input_features=batch_dict["frustum_features_target"],
+                                             grid=grid) # (B, C, X, Y, Z)
+        voxel_features_target = voxel_features_target.permute(0, 1, 4, 3, 2)
+        batch_dict["voxel_features_target"] = voxel_features_target
         return batch_dict
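Note: the uncommented lines run the existing frustum-to-voxel sampler on the target frustum features as well. To a first approximation that sampler is a trilinear F.grid_sample over the frustum volume followed by an axis permutation; the snippet below is a rough, self-contained sketch of that idea with made-up shapes and a random placeholder grid, not the repository's sampler module.

import torch
import torch.nn.functional as F

# Assumed shapes: frustum volume (B, C, D, H, W) and a normalized voxel sampling grid (B, X, Y, Z, 3).
B, C, D, H, W = 2, 64, 120, 47, 156
X, Y, Z = 176, 200, 8
frustum_features_target = torch.randn(B, C, D, H, W)
grid = torch.rand(B, X, Y, Z, 3) * 2 - 1   # placeholder grid with coordinates in [-1, 1]

# Trilinearly sample the frustum volume at every voxel location -> (B, C, X, Y, Z).
voxel_features_target = F.grid_sample(frustum_features_target, grid, align_corners=False)

# Reorder to (B, C, Z, Y, X) so the height axis can later be folded into BEV channels.
voxel_features_target = voxel_features_target.permute(0, 1, 4, 3, 2)
print(voxel_features_target.shape)         # torch.Size([2, 64, 8, 200, 176])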

pcdet/models/backbones_3d/vfe/image_vfe_modules/ffn/depth_ffn.py

+81 -25
@@ -3,7 +3,10 @@
 
 from . import ddn, ddn_loss
 from pcdet.models.model_utils.basic_block_2d import BasicBlock2D
-
+from skimage import transform
+import numpy as np
+import torch
+import matplotlib.pyplot as plt
 
 class DepthFFN(nn.Module):
 
@@ -30,18 +33,51 @@ def __init__(self, model_cfg, downsample_factor):
         self.channel_reduce = BasicBlock2D(**model_cfg.CHANNEL_REDUCE)
 
         # DDN_LOSS is optional
-        if model_cfg.get('LOSS',None) is not None:
-            self.ddn_loss = ddn_loss.__all__[model_cfg.LOSS.NAME](
+        if model_cfg.get('LOSS_',None) is not None:
+            self.ddn_loss = ddn_loss.__all__[model_cfg.LOSS_.NAME](
                 disc_cfg=self.disc_cfg,
                 downsample_factor=downsample_factor,
-                **model_cfg.LOSS.ARGS
+                **model_cfg.LOSS_.ARGS
             )
+
         else:
             self.ddn_loss = None
         self.forward_ret_dict = {}
 
     def get_output_feature_dim(self):
         return self.channel_reduce.out_channels
+    # sparse average pooling, by hoiliu
+    def sparse_avg_pooling(self, feature_map, size=2):
+        feature_map = feature_map.cpu().detach()
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        Batch = feature_map.shape[0]
+        pool_out_list = []
+        for i in range(Batch):
+            a = transform.downscale_local_mean(feature_map[i,:,:],(size,size))
+            b = transform.downscale_local_mean(feature_map[i,:,:]!=0,(size,size))
+            pool_out = a / (b+1e-10)
+            pool_out_list.append(torch.from_numpy(pool_out).float().to(device))
+        pool_out1 = torch.stack(pool_out_list)
+        return pool_out1
+    def create_depth_target(self, depth_map_target, depth_target_bin):
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        B, h, w = depth_map_target.shape # torch.Size([2, 47, 156])
+        D = 120
+        depth_target = torch.from_numpy(np.zeros([B, D+1, h, w])).float().to(device)
+        for b in range(B):
+            for i in range(h):
+                for j in range(w):
+                    bin_value = depth_target_bin[b,i,j]
+                    depth_target[b,bin_value,i,j] = 1
+                    # if bin_value==120: # out of boundary
+                    # # print("depth_map_target:", depth_map_target[b,i,j])
+                    # depth_target[b,bin_value,i,j] = 100000
+                    # elif bin_value>120 or bin_value<0:
+                    # print("error bin")
+                    # else:
+                    # depth_target[b,bin_value,i,j] = 1
+
+        return depth_target
 
     def forward(self, batch_dict):
         """
@@ -53,29 +89,45 @@ def forward(self, batch_dict):
             batch_dict:
                 frustum_features: (N, C, D, H_out, W_out), Image depth features
         """
+        # print("batch_dict['frame_id']:", batch_dict['frame_id'][0])
         # Pixel-wise depth classification
-        images = batch_dict["images"]
-        ddn_result = self.ddn(images)
-        image_features = ddn_result["features"]
-        depth_logits = ddn_result["logits"]
-
+        images = batch_dict["images"] #([2, 3, 375, 1242])
+        ddn_result = self.ddn(images) # self.ddn is a pretrained backbone, which is used to generate pretrained depth feature and depth bin
+        image_features = ddn_result["features"] #([2, 1024, 47, 156])
+        depth_logits = ddn_result["logits"] #([2, 121, 47, 156])
         # Channel reduce
         if self.channel_reduce is not None:
-            image_features = self.channel_reduce(image_features)
+            image_features = self.channel_reduce(image_features) # 1024 -> 64
 
         # Create image feature plane-sweep volume
         frustum_features = self.create_frustum_features(image_features=image_features,
                                                         depth_logits=depth_logits)
         batch_dict["frustum_features"] = frustum_features
-
-        if self.training:
-            # depth_maps and gt_boxes2d are optional
-            self.forward_ret_dict["depth_maps"] = batch_dict.get("depth_maps",None)
-            self.forward_ret_dict["gt_boxes2d"] = batch_dict.get("gt_boxes2d",None)
-            self.forward_ret_dict["depth_logits"] = depth_logits
+        batch_dict["image_features"] = image_features
+
+        # depth_maps and gt_boxes2d are optional
+        self.forward_ret_dict["depth_maps"] = batch_dict.get("depth_maps",None)
+        self.forward_ret_dict["gt_boxes2d"] = batch_dict.get("gt_boxes2d",None)
+        self.forward_ret_dict["depth_logits"] = depth_logits # torch.Size([2, 121, 47, 156])
+        #### New code ####
+        #### Create Lidar-image-lije feature plane-sweep volume ###
+        self.forward_ret_dict["depth_maps"] = self.sparse_avg_pooling(self.forward_ret_dict["depth_maps"], 8)
+        # save_path="/home/ipl-pc/cmkd/output/vis_result"+".depth.png"
+        # print(self.forward_ret_dict["depth_maps"])
+        # exit()
+        # plt.imsave(save_path, self.forward_ret_dict["depth_maps"][0,:].cpu().detach())
+
+        depth_map_target = self.forward_ret_dict["depth_maps"] ## ([47, 156])
+        depth_target_bin = self.ddn_loss(**self.forward_ret_dict) # 0-120, total 121 dim
+        # print(np.unique(depth_target_bin[0,:].cpu().detach())) # 29-120
+        depth_target = self.create_depth_target(depth_map_target, depth_target_bin)
+        frustum_features_target = self.create_frustum_features(image_features, depth_target, target=True)
+        batch_dict["frustum_features_target"] = frustum_features_target
+        frustum_features = self.create_frustum_features(image_features=image_features,
+                                                        depth_logits=depth_logits)
         return batch_dict
 
-    def create_frustum_features(self, image_features, depth_logits):
+    def create_frustum_features(self, image_features, depth_logits, target=False):
         """
         Create image depth feature volume by multiplying image features with depth distributions
         Args:
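Note: sparse_avg_pooling (added above and called in forward with size=8) downsamples a sparse LiDAR depth map by dividing the local mean of the values by the local mean of the non-zero mask, so empty pixels do not drag the average toward zero. A GPU-only equivalent using F.avg_pool2d is sketched below for illustration; unlike skimage.transform.downscale_local_mean it does not pad ragged borders, so it assumes H and W are divisible by the pooling size.

import torch
import torch.nn.functional as F

def sparse_avg_pool2d(depth_map, size=2):
    # depth_map: (B, H, W) sparse depth, zeros mark empty pixels.
    x = depth_map.unsqueeze(1)                    # (B, 1, H, W)
    valid = (x != 0).float()
    num = F.avg_pool2d(x, kernel_size=size)       # mean over all pixels in each block
    den = F.avg_pool2d(valid, kernel_size=size)   # fraction of non-zero pixels per block
    return (num / (den + 1e-10)).squeeze(1)       # mean over the non-zero pixels only

# Example: a 2x downsample of a 4x4 map with two hits in the top-left block.
d = torch.zeros(1, 4, 4)
d[0, 0, 0], d[0, 1, 1] = 10.0, 20.0
print(sparse_avg_pool2d(d, 2)[0, 0, 0])           # ~15.0 rather than the diluted 7.5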
@@ -88,15 +140,19 @@ def create_frustum_features(self, image_features, depth_logits):
         depth_dim = 2
 
         # Resize to match dimensions
-        image_features = image_features.unsqueeze(depth_dim)
-        depth_logits = depth_logits.unsqueeze(channel_dim)
-
+        image_features = image_features.unsqueeze(depth_dim) # [2, 64, 47, 156] -> [2, 64, 1, 47, 156]
+        depth_logits = depth_logits.unsqueeze(channel_dim) # [2, 120, 47, 156] -> [2, 1, 120, 47, 156]
         # Apply softmax along depth axis and remove last depth category (> Max Range)
-        depth_probs = F.softmax(depth_logits, dim=depth_dim)
-        depth_probs = depth_probs[:, :, :-1]
-
+        # print("depth_logits:", depth_logits[:,:,-1,:,:])
+        if target:
+            depth_probs = depth_logits[:, :, :-1] # [2, 1, 120, 47, 156]
+            # print("depth_probs:", depth_probs)
+        else:
+            depth_probs = F.softmax(depth_logits, dim=depth_dim) # [2, 1, 121, 47, 156]
+            depth_probs = depth_probs[:, :, :-1] # [2, 1, 120, 47, 156]
+            # print("depth_probs__:", depth_probs)
         # Multiply to form image depth feature volume
-        frustum_features = depth_probs * image_features
+        frustum_features = depth_probs * image_features # [2, 64, 120, 47, 156] = [2, 1, 120, 47, 156] * [2, 64, 1, 47, 156]
         return frustum_features
 
     def get_loss(self):
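Note: create_frustum_features builds the plane-sweep volume by broadcasting a (B, 1, D, H, W) depth distribution against (B, C, 1, H, W) image features; with target=True the LiDAR-derived one-hot volume is used directly in place of the softmax over predicted logits, so each pixel's features are pushed into a single depth bin. A shape-level sketch of the broadcast, using sizes assumed from the comments above:

import torch
import torch.nn.functional as F

B, C, D, H, W = 2, 64, 121, 47, 156               # last depth bin = "beyond max range"
image_features = torch.randn(B, C, H, W)
depth_logits = torch.randn(B, D, H, W)

feats = image_features.unsqueeze(2)               # (B, C, 1, H, W)
probs = F.softmax(depth_logits, dim=1).unsqueeze(1)[:, :, :-1]   # (B, 1, D-1, H, W)

frustum_features = probs * feats                  # broadcasts to (B, C, D-1, H, W)
print(frustum_features.shape)                     # torch.Size([2, 64, 120, 47, 156])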
@@ -108,4 +164,4 @@ def get_loss(self):
             tb_dict: dict[float], All losses to log in tensorboard
         """
         loss, tb_dict = self.ddn_loss(**self.forward_ret_dict)
-        return loss, tb_dict
+        return loss, tb_dict
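Note: create_depth_target (added in this commit) expands the (B, h, w) map of target depth-bin indices into a one-hot (B, D+1, h, w) volume with three nested Python loops. The same tensor can be produced in a single vectorized call; the snippet below is an illustrative alternative, not the committed implementation.

import torch
import torch.nn.functional as F

def create_depth_target_onehot(depth_target_bin, num_bins=120):
    # depth_target_bin: (B, h, w) integer bin indices in [0, num_bins].
    one_hot = F.one_hot(depth_target_bin.long(), num_classes=num_bins + 1)  # (B, h, w, D+1)
    return one_hot.permute(0, 3, 1, 2).float()                              # (B, D+1, h, w)

# Shapes taken from the comments in the diff: B=2, h=47, w=156, bins 0-120.
bins = torch.randint(0, 121, (2, 47, 156))
print(create_depth_target_onehot(bins).shape)    # torch.Size([2, 121, 47, 156])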

pcdet/models/backbones_3d/vfe/image_vfe_modules/ffn/depth_ffn_.py

-167: This file was deleted.
There was a problem loading the remainder of the diff.

0 commit comments
