-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a4f2425
commit 947fd58
Showing
20 changed files
with
2,122 additions
and
1 deletion.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
from v2xvit.models.sub_modules.pillar_vfe import PillarVFE | ||
from v2xvit.models.sub_modules.point_pillar_scatter import PointPillarScatter | ||
from v2xvit.models.sub_modules.base_bev_backbone import BaseBEVBackbone | ||
from v2xvit.models.sub_modules.fuse_utils import regroup | ||
from v2xvit.models.sub_modules.downsample_conv import DownsampleConv | ||
from v2xvit.models.sub_modules.naive_compress import NaiveCompressor | ||
from v2xvit.models.sub_modules.v2xvit_basic import V2XTransformer | ||
|
||
|
||
class PointPillarTransformer(nn.Module):
    """PointPillar feature extractor followed by V2X transformer fusion.

    The pipeline built here is: ``PillarVFE`` -> ``PointPillarScatter`` ->
    ``BaseBEVBackbone`` -> optional ``DownsampleConv`` -> optional
    ``NaiveCompressor`` -> ``V2XTransformer`` -> 1x1 conv heads.

    Args:
        args: Mapping of model settings. Keys that are always read:
            ``max_cav``, ``pillar_vfe``, ``voxel_size``, ``lidar_range``,
            ``point_pillar_scatter``, ``base_bev_backbone``, ``transformer``
            and ``anchor_number``. Optional keys: ``shrink_header``
            (enables the downsample conv), ``compression`` (a value > 0
            enables the compressor) and ``backbone_fix`` (truthy freezes
            everything except the fusion network).
    """

    def __init__(self, args):
        super(PointPillarTransformer, self).__init__()

        self.max_cav = args['max_cav']
        # Pillar VFE
        self.pillar_vfe = PillarVFE(args['pillar_vfe'],
                                    num_point_features=4,
                                    voxel_size=args['voxel_size'],
                                    point_cloud_range=args['lidar_range'])
        self.scatter = PointPillarScatter(args['point_pillar_scatter'])
        self.backbone = BaseBEVBackbone(args['base_bev_backbone'], 64)

        # used to downsample the feature map for efficient computation
        self.shrink_flag = False
        if 'shrink_header' in args:
            self.shrink_flag = True
            self.shrink_conv = DownsampleConv(args['shrink_header'])

        # 'compression' is treated as optional, like 'shrink_header', so a
        # config that omits it no longer fails with KeyError.
        self.compression = False
        if 'compression' in args and args['compression'] > 0:
            self.compression = True
            self.naive_compressor = NaiveCompressor(256, args['compression'])

        self.fusion_net = V2XTransformer(args['transformer'])

        self.cls_head = nn.Conv2d(128 * 2, args['anchor_number'],
                                  kernel_size=1)
        self.reg_head = nn.Conv2d(128 * 2, 7 * args['anchor_number'],
                                  kernel_size=1)

        # Only freeze when explicitly requested; a missing key means "no".
        if 'backbone_fix' in args and args['backbone_fix']:
            self.backbone_fix()

    def backbone_fix(self):
        """
        Fix the parameters of backbone during finetune on timedelay.

        Every sub-module created in ``__init__`` except ``fusion_net`` has
        ``requires_grad`` switched off on all of its parameters.
        """
        frozen_modules = [self.pillar_vfe, self.scatter, self.backbone]
        # The optional stages only exist as attributes when their flag is set.
        if self.compression:
            frozen_modules.append(self.naive_compressor)
        if self.shrink_flag:
            frozen_modules.append(self.shrink_conv)
        frozen_modules.extend([self.cls_head, self.reg_head])

        for module in frozen_modules:
            for p in module.parameters():
                p.requires_grad = False

    def forward(self, data_dict):
        """Run detection on a batch.

        Args:
            data_dict: Mapping providing ``processed_lidar`` (with
                ``voxel_features``, ``voxel_coords``, ``voxel_num_points``),
                ``record_len``, ``spatial_correction_matrix`` and
                ``prior_encoding``.

        Returns:
            dict with ``'psm'`` (output of the classification head) and
            ``'rm'`` (output of the regression head).
        """
        voxel_features = data_dict['processed_lidar']['voxel_features']
        voxel_coords = data_dict['processed_lidar']['voxel_coords']
        voxel_num_points = data_dict['processed_lidar']['voxel_num_points']
        record_len = data_dict['record_len']
        spatial_correction_matrix = data_dict['spatial_correction_matrix']

        # B, max_cav, 3(dt dv infra), 1, 1
        prior_encoding = \
            data_dict['prior_encoding'].unsqueeze(-1).unsqueeze(-1)

        batch_dict = {'voxel_features': voxel_features,
                      'voxel_coords': voxel_coords,
                      'voxel_num_points': voxel_num_points,
                      'record_len': record_len}
        # n, 4 -> n, c
        batch_dict = self.pillar_vfe(batch_dict)
        # n, c -> N, C, H, W
        batch_dict = self.scatter(batch_dict)
        batch_dict = self.backbone(batch_dict)

        spatial_features_2d = batch_dict['spatial_features_2d']
        # downsample feature to reduce memory
        if self.shrink_flag:
            spatial_features_2d = self.shrink_conv(spatial_features_2d)
        # compressor
        if self.compression:
            spatial_features_2d = self.naive_compressor(spatial_features_2d)
        # N, C, H, W -> B, L, C, H, W
        regroup_feature, mask = regroup(spatial_features_2d,
                                        record_len,
                                        self.max_cav)
        # prior encoding added: tile it over the spatial grid and append it
        # as extra channels
        prior_encoding = prior_encoding.repeat(1, 1, 1,
                                               regroup_feature.shape[3],
                                               regroup_feature.shape[4])
        regroup_feature = torch.cat([regroup_feature, prior_encoding], dim=2)

        # b l c h w -> b l h w c
        regroup_feature = regroup_feature.permute(0, 1, 3, 4, 2)
        # transformer fusion
        fused_feature = self.fusion_net(regroup_feature, mask,
                                        spatial_correction_matrix)
        # b h w c -> b c h w
        fused_feature = fused_feature.permute(0, 3, 1, 2)

        psm = self.cls_head(fused_feature)
        rm = self.reg_head(fused_feature)

        output_dict = {'psm': psm,
                       'rm': rm}

        return output_dict
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
class BaseBEVBackbone(nn.Module):
    """Multi-level 2D convolutional backbone with per-level (de)upsampling.

    Each level is a strided 3x3 conv block followed by ``layer_nums[idx]``
    additional 3x3 conv blocks. When ``upsample_strides`` is configured,
    every level output is passed through its own "deblock" and the results
    are concatenated along the channel axis.

    Args:
        model_cfg: Mapping that may contain ``layer_nums``,
            ``layer_strides``, ``num_filters`` (all the same length) and
            ``upsample_strides``, ``num_upsample_filter`` (same length).
            Missing groups default to empty lists.
        input_channels: Channel count of the input feature map.

    Attributes:
        num_bev_features: Sum of ``num_upsample_filter`` (0 when absent).
    """

    def __init__(self, model_cfg, input_channels):
        super().__init__()
        self.model_cfg = model_cfg

        if 'layer_nums' in self.model_cfg:

            assert len(self.model_cfg['layer_nums']) == \
                   len(self.model_cfg['layer_strides']) == \
                   len(self.model_cfg['num_filters'])

            layer_nums = self.model_cfg['layer_nums']
            layer_strides = self.model_cfg['layer_strides']
            num_filters = self.model_cfg['num_filters']
        else:
            layer_nums = layer_strides = num_filters = []

        if 'upsample_strides' in self.model_cfg:
            assert len(self.model_cfg['upsample_strides']) \
                   == len(self.model_cfg['num_upsample_filter'])

            num_upsample_filters = self.model_cfg['num_upsample_filter']
            upsample_strides = self.model_cfg['upsample_strides']

        else:
            upsample_strides = num_upsample_filters = []

        num_levels = len(layer_nums)
        # Level idx consumes the output channels of level idx - 1.
        c_in_list = [input_channels, *num_filters[:-1]]

        self.blocks = nn.ModuleList()
        self.deblocks = nn.ModuleList()

        for idx in range(num_levels):
            cur_layers = [
                nn.ZeroPad2d(1),
                nn.Conv2d(
                    c_in_list[idx], num_filters[idx], kernel_size=3,
                    stride=layer_strides[idx], padding=0, bias=False
                ),
                nn.BatchNorm2d(num_filters[idx], eps=1e-3, momentum=0.01),
                nn.ReLU()
            ]
            for _ in range(layer_nums[idx]):
                cur_layers.extend([
                    nn.Conv2d(num_filters[idx], num_filters[idx],
                              kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(num_filters[idx], eps=1e-3, momentum=0.01),
                    nn.ReLU()
                ])

            self.blocks.append(nn.Sequential(*cur_layers))
            if len(upsample_strides) > 0:
                stride = upsample_strides[idx]
                if stride >= 1:
                    self.deblocks.append(nn.Sequential(
                        nn.ConvTranspose2d(
                            num_filters[idx], num_upsample_filters[idx],
                            upsample_strides[idx],
                            stride=upsample_strides[idx], bias=False
                        ),
                        nn.BatchNorm2d(num_upsample_filters[idx],
                                       eps=1e-3, momentum=0.01),
                        nn.ReLU()
                    ))
                else:
                    # A fractional stride means "downsample by 1/stride".
                    # The builtin int is used because the np.int alias was
                    # removed from NumPy and raises AttributeError there.
                    stride = int(round(1 / stride))
                    self.deblocks.append(nn.Sequential(
                        nn.Conv2d(
                            num_filters[idx], num_upsample_filters[idx],
                            stride,
                            stride=stride, bias=False
                        ),
                        nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3,
                                       momentum=0.01),
                        nn.ReLU()
                    ))

        c_in = sum(num_upsample_filters)
        # One extra upsample stride adds a final deblock that is applied to
        # the concatenated feature map.
        if len(upsample_strides) > num_levels:
            self.deblocks.append(nn.Sequential(
                nn.ConvTranspose2d(c_in, c_in, upsample_strides[-1],
                                   stride=upsample_strides[-1], bias=False),
                nn.BatchNorm2d(c_in, eps=1e-3, momentum=0.01),
                nn.ReLU(),
            ))

        self.num_bev_features = c_in

    def forward(self, data_dict):
        """Compute the fused 2D feature map.

        Args:
            data_dict: Mapping holding the input under ``spatial_features``.

        Returns:
            The same ``data_dict`` object, with the result stored under
            ``spatial_features_2d``.
        """
        x = data_dict['spatial_features']
        has_deblocks = len(self.deblocks) > 0

        ups = []
        for i, block in enumerate(self.blocks):
            x = block(x)
            ups.append(self.deblocks[i](x) if has_deblocks else x)

        if len(ups) > 1:
            x = torch.cat(ups, dim=1)
        elif len(ups) == 1:
            x = ups[0]

        if len(self.deblocks) > len(self.blocks):
            x = self.deblocks[-1](x)

        data_dict['spatial_features_2d'] = x
        return data_dict
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import torch | ||
from torch import nn | ||
|
||
from einops import rearrange | ||
|
||
|
||
class PreNorm(nn.Module):
    """Wrap a callable so that its input is layer-normalised first.

    Args:
        dim: Size of the last dimension, passed to ``nn.LayerNorm``.
        fn: Callable invoked on the normalised tensor.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Normalise ``x`` and hand it to ``fn`` along with any keyword arguments."""
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
|
||
|
||
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability used after both linear layers.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.net(x)
|
||
|
||
class CavAttention(nn.Module):
    """
    Vanilla CAV attention.

    Multi-head scaled dot-product attention computed across the agent axis
    ``L`` independently at every spatial location.

    Args:
        dim: Channel size of the input and output features.
        heads: Number of attention heads.
        dim_head: Channel size of each head.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, heads, dim_head=64, dropout=0.1):
        super().__init__()
        inner_dim = heads * dim_head

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def _split_heads(self, t):
        """(B, H, W, L, M * C_head) -> (B, M, H, W, L, C_head)."""
        t = t.reshape(*t.shape[:-1], self.heads, -1)
        return t.permute(0, 4, 1, 2, 3, 5)

    def forward(self, x, mask, prior_encoding=None):
        """Attend across agents.

        Args:
            x: Features laid out as (B, L, H, W, C).
            mask: Tensor whose zeros mark attention entries to suppress.
                NOTE(review): after ``unsqueeze(1)`` it must broadcast
                against the (B, M, H, W, L, L) attention map, e.g. a
                (B, H, W, 1, L) input; a plain (B, L) mask does not
                broadcast in general -- confirm against the caller.
            prior_encoding: Accepted for interface compatibility and not
                used in the computation. Optional, so callers that only
                pass ``x`` and ``mask`` no longer raise ``TypeError``.

        Returns:
            Tensor laid out as (B, L, H, W, C).
        """
        # x: (B, L, H, W, C) -> (B, H, W, L, C)
        x = x.permute(0, 2, 3, 1, 4)
        # insert a head axis so the mask is shared by all heads
        mask = mask.unsqueeze(1)

        # qkv: [(B, H, W, L, C_inner) *3]
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # q: (B, M, H, W, L, C)
        q, k, v = (self._split_heads(t) for t in qkv)

        # attention, (B, M, H, W, L, L)
        att_map = torch.einsum('b m h w i c, b m h w j c -> b m h w i j',
                               q, k) * self.scale
        # add mask
        att_map = att_map.masked_fill(mask == 0, -float('inf'))
        # softmax
        att_map = self.attend(att_map)

        # out:(B, M, H, W, L, C_head)
        out = torch.einsum('b m h w i j, b m h w j c -> b m h w i c', att_map,
                           v)
        # merge heads: (B, M, H, W, L, C_head) -> (B, H, W, L, M * C_head)
        out = out.permute(0, 2, 3, 4, 1, 5).flatten(-2)
        out = self.to_out(out)
        # (B L H W C)
        out = out.permute(0, 3, 1, 2, 4)
        return out
|
||
|
||
class BaseEncoder(nn.Module):
    """Stack of pre-norm attention / feed-forward layers with residuals.

    Args:
        dim: Feature channel size.
        depth: Number of (attention, feed-forward) layer pairs.
        heads: Number of attention heads.
        dim_head: Channel size per head.
        mlp_dim: Hidden width of the feed-forward layer.
        dropout: Dropout probability for both sub-layers.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, CavAttention(dim,
                                          heads=heads,
                                          dim_head=dim_head,
                                          dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x, mask, prior_encoding=None):
        """Apply every layer pair in order.

        Args:
            x: Input features.
            mask: Mask forwarded to each attention layer.
            prior_encoding: Optional value forwarded to each attention
                layer. It is always passed explicitly because the attention
                layer declares it as a parameter; omitting it raised
                ``TypeError``.

        Returns:
            Tensor produced by the last layer pair.
        """
        for attn, ff in self.layers:
            x = attn(x, mask=mask, prior_encoding=prior_encoding) + x
            x = ff(x) + x
        return x
|
||
|
||
class BaseTransformer(nn.Module):
    """Encoder wrapper that returns the features at index 0 of the agent axis.

    Args:
        args: Mapping with the keys ``dim``, ``depth``, ``heads``,
            ``dim_head``, ``mlp_dim``, ``dropout`` and ``max_cav``.
    """

    def __init__(self, args):
        super().__init__()

        # Read in the same order as the encoder's positional parameters.
        encoder_keys = ('dim', 'depth', 'heads', 'dim_head', 'mlp_dim',
                        'dropout')
        encoder_args = [args[key] for key in encoder_keys]
        # 'max_cav' is looked up but its value is not used in this class;
        # the lookup is kept so a config missing the key still raises.
        _ = args['max_cav']

        self.encoder = BaseEncoder(*encoder_args)

    def forward(self, x, mask):
        """Encode ``x`` (B, L, H, W, C) and return slice 0 along L: (B, H, W, C)."""
        encoded = self.encoder(x, mask)
        return encoded[:, 0]
Oops, something went wrong.