model added
DerrickXuNu committed Jul 9, 2022
1 parent a4f2425 commit 947fd58
Showing 20 changed files with 2,122 additions and 1 deletion.
Empty file added v2xvit/models/__init__.py
122 changes: 122 additions & 0 deletions v2xvit/models/point_pillar_transformer.py
@@ -0,0 +1,122 @@
import torch
import torch.nn as nn

from v2xvit.models.sub_modules.pillar_vfe import PillarVFE
from v2xvit.models.sub_modules.point_pillar_scatter import PointPillarScatter
from v2xvit.models.sub_modules.base_bev_backbone import BaseBEVBackbone
from v2xvit.models.sub_modules.fuse_utils import regroup
from v2xvit.models.sub_modules.downsample_conv import DownsampleConv
from v2xvit.models.sub_modules.naive_compress import NaiveCompressor
from v2xvit.models.sub_modules.v2xvit_basic import V2XTransformer


class PointPillarTransformer(nn.Module):
def __init__(self, args):
super(PointPillarTransformer, self).__init__()

self.max_cav = args['max_cav']
        # Pillar VFE
self.pillar_vfe = PillarVFE(args['pillar_vfe'],
num_point_features=4,
voxel_size=args['voxel_size'],
point_cloud_range=args['lidar_range'])
self.scatter = PointPillarScatter(args['point_pillar_scatter'])
self.backbone = BaseBEVBackbone(args['base_bev_backbone'], 64)
# used to downsample the feature map for efficient computation
self.shrink_flag = False
if 'shrink_header' in args:
self.shrink_flag = True
self.shrink_conv = DownsampleConv(args['shrink_header'])
self.compression = False

if args['compression'] > 0:
self.compression = True
self.naive_compressor = NaiveCompressor(256, args['compression'])

self.fusion_net = V2XTransformer(args['transformer'])

self.cls_head = nn.Conv2d(128 * 2, args['anchor_number'],
kernel_size=1)
self.reg_head = nn.Conv2d(128 * 2, 7 * args['anchor_number'],
kernel_size=1)

if args['backbone_fix']:
self.backbone_fix()

def backbone_fix(self):
"""
Fix the parameters of backbone during finetune on timedelay。
"""
for p in self.pillar_vfe.parameters():
p.requires_grad = False

for p in self.scatter.parameters():
p.requires_grad = False

for p in self.backbone.parameters():
p.requires_grad = False

if self.compression:
for p in self.naive_compressor.parameters():
p.requires_grad = False
if self.shrink_flag:
for p in self.shrink_conv.parameters():
p.requires_grad = False

for p in self.cls_head.parameters():
p.requires_grad = False
for p in self.reg_head.parameters():
p.requires_grad = False

def forward(self, data_dict):
voxel_features = data_dict['processed_lidar']['voxel_features']
voxel_coords = data_dict['processed_lidar']['voxel_coords']
voxel_num_points = data_dict['processed_lidar']['voxel_num_points']
record_len = data_dict['record_len']
spatial_correction_matrix = data_dict['spatial_correction_matrix']

        # prior encoding: (B, max_cav, 3, 1, 1); the 3 channels are
        # time delay (dt), speed (dv) and the infrastructure flag
        prior_encoding = \
            data_dict['prior_encoding'].unsqueeze(-1).unsqueeze(-1)

batch_dict = {'voxel_features': voxel_features,
'voxel_coords': voxel_coords,
'voxel_num_points': voxel_num_points,
'record_len': record_len}
# n, 4 -> n, c
batch_dict = self.pillar_vfe(batch_dict)
# n, c -> N, C, H, W
batch_dict = self.scatter(batch_dict)
batch_dict = self.backbone(batch_dict)

spatial_features_2d = batch_dict['spatial_features_2d']
# downsample feature to reduce memory
if self.shrink_flag:
spatial_features_2d = self.shrink_conv(spatial_features_2d)
# compressor
if self.compression:
spatial_features_2d = self.naive_compressor(spatial_features_2d)
# N, C, H, W -> B, L, C, H, W
regroup_feature, mask = regroup(spatial_features_2d,
record_len,
self.max_cav)
        # broadcast the prior encoding over (H, W) and append it as channels
prior_encoding = prior_encoding.repeat(1, 1, 1,
regroup_feature.shape[3],
regroup_feature.shape[4])
regroup_feature = torch.cat([regroup_feature, prior_encoding], dim=2)

# b l c h w -> b l h w c
regroup_feature = regroup_feature.permute(0, 1, 3, 4, 2)
# transformer fusion
fused_feature = self.fusion_net(regroup_feature, mask, spatial_correction_matrix)
# b h w c -> b c h w
fused_feature = fused_feature.permute(0, 3, 1, 2)

psm = self.cls_head(fused_feature)
rm = self.reg_head(fused_feature)

output_dict = {'psm': psm,
'rm': rm}

return output_dict
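
Note: below is a minimal sketch of the forward-pass input contract, inferred from forward() above. All shapes and values are illustrative assumptions; the real tensors come from the repo's data loader and config YAML, which are not part of this commit view.

# Hypothetical example (assumed shapes, not the shipped config).
n, max_pts, B, max_cav = 6000, 32, 2, 5
data_dict = {
    'processed_lidar': {
        'voxel_features': torch.rand(n, max_pts, 4),    # points per pillar: x, y, z, intensity
        'voxel_coords': torch.zeros(n, 4),              # (batch_idx, z, y, x) pillar indices
        'voxel_num_points': torch.full((n,), max_pts),  # valid point count per pillar
    },
    'record_len': torch.tensor([3, 2]),                 # CAVs per sample; sums to N
    'spatial_correction_matrix': torch.eye(4).expand(B, max_cav, 4, 4),
    'prior_encoding': torch.zeros(B, max_cav, 3),       # (dt, dv, infra) per CAV
}
model = PointPillarTransformer(args)   # args: the model section of a config YAML
output_dict = model(data_dict)
# output_dict['psm']: (B, anchor_number, H, W) classification map
# output_dict['rm']:  (B, 7 * anchor_number, H, W) box regression map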
Empty file added v2xvit/models/sub_modules/__init__.py
122 changes: 122 additions & 0 deletions v2xvit/models/sub_modules/base_bev_backbone.py
@@ -0,0 +1,122 @@
import numpy as np
import torch
import torch.nn as nn


class BaseBEVBackbone(nn.Module):
def __init__(self, model_cfg, input_channels):
super().__init__()
self.model_cfg = model_cfg

if 'layer_nums' in self.model_cfg:

assert len(self.model_cfg['layer_nums']) == \
len(self.model_cfg['layer_strides']) == \
len(self.model_cfg['num_filters'])

layer_nums = self.model_cfg['layer_nums']
layer_strides = self.model_cfg['layer_strides']
num_filters = self.model_cfg['num_filters']
else:
layer_nums = layer_strides = num_filters = []

if 'upsample_strides' in self.model_cfg:
assert len(self.model_cfg['upsample_strides']) \
== len(self.model_cfg['num_upsample_filter'])

num_upsample_filters = self.model_cfg['num_upsample_filter']
upsample_strides = self.model_cfg['upsample_strides']

else:
upsample_strides = num_upsample_filters = []

num_levels = len(layer_nums)
c_in_list = [input_channels, *num_filters[:-1]]

self.blocks = nn.ModuleList()
self.deblocks = nn.ModuleList()

for idx in range(num_levels):
cur_layers = [
nn.ZeroPad2d(1),
nn.Conv2d(
c_in_list[idx], num_filters[idx], kernel_size=3,
stride=layer_strides[idx], padding=0, bias=False
),
nn.BatchNorm2d(num_filters[idx], eps=1e-3, momentum=0.01),
nn.ReLU()
]
for k in range(layer_nums[idx]):
cur_layers.extend([
nn.Conv2d(num_filters[idx], num_filters[idx],
kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(num_filters[idx], eps=1e-3, momentum=0.01),
nn.ReLU()
])

self.blocks.append(nn.Sequential(*cur_layers))
if len(upsample_strides) > 0:
stride = upsample_strides[idx]
if stride >= 1:
self.deblocks.append(nn.Sequential(
nn.ConvTranspose2d(
num_filters[idx], num_upsample_filters[idx],
upsample_strides[idx],
stride=upsample_strides[idx], bias=False
),
nn.BatchNorm2d(num_upsample_filters[idx],
eps=1e-3, momentum=0.01),
nn.ReLU()
))
else:
                    stride = int(np.round(1 / stride))
self.deblocks.append(nn.Sequential(
nn.Conv2d(
num_filters[idx], num_upsample_filters[idx],
stride,
stride=stride, bias=False
),
nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3,
momentum=0.01),
nn.ReLU()
))

c_in = sum(num_upsample_filters)
if len(upsample_strides) > num_levels:
self.deblocks.append(nn.Sequential(
nn.ConvTranspose2d(c_in, c_in, upsample_strides[-1],
stride=upsample_strides[-1], bias=False),
nn.BatchNorm2d(c_in, eps=1e-3, momentum=0.01),
nn.ReLU(),
))

self.num_bev_features = c_in

def forward(self, data_dict):
spatial_features = data_dict['spatial_features']

ups = []
ret_dict = {}
x = spatial_features

for i in range(len(self.blocks)):
x = self.blocks[i](x)

stride = int(spatial_features.shape[2] / x.shape[2])
ret_dict['spatial_features_%dx' % stride] = x

if len(self.deblocks) > 0:
ups.append(self.deblocks[i](x))
else:
ups.append(x)

if len(ups) > 1:
x = torch.cat(ups, dim=1)
elif len(ups) == 1:
x = ups[0]

if len(self.deblocks) > len(self.blocks):
x = self.deblocks[-1](x)

data_dict['spatial_features_2d'] = x
return data_dict
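
For orientation, here is a hypothetical three-level configuration consistent with the assertions in __init__ above; the numbers are illustrative, not taken from the repo's YAML.

# Assumed config: each level halves the resolution, and each deblock
# upsamples its output back to the first level's scale so the three
# feature maps can be concatenated along channels.
model_cfg = {
    'layer_nums': [3, 5, 8],
    'layer_strides': [2, 2, 2],
    'num_filters': [64, 128, 256],
    'upsample_strides': [1, 2, 4],
    'num_upsample_filter': [128, 128, 128],
}
backbone = BaseBEVBackbone(model_cfg, input_channels=64)
# data_dict['spatial_features']:    (N, 64, H, W)
# data_dict['spatial_features_2d']: (N, 384, H/2, W/2), since 3 * 128 = 384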
124 changes: 124 additions & 0 deletions v2xvit/models/sub_modules/base_transformer.py
@@ -0,0 +1,124 @@
import torch
from torch import nn

from einops import rearrange


class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn

def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)

def forward(self, x):
return self.net(x)


class CavAttention(nn.Module):
"""
Vanilla CAV attention.
"""
def __init__(self, dim, heads, dim_head=64, dropout=0.1):
super().__init__()
inner_dim = heads * dim_head

self.heads = heads
self.scale = dim_head ** -0.5

self.attend = nn.Softmax(dim=-1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)

    def forward(self, x, mask, prior_encoding=None):
        # prior_encoding is unused here; it gets a default so the encoder
        # can call this layer with keyword arguments only
        # x: (B, L, H, W, C) -> (B, H, W, L, C)
        x = x.permute(0, 2, 3, 1, 4)
        # mask: (B, H, W, L, 1) -> (B, 1, H, W, L, 1), broadcast over heads
        mask = mask.unsqueeze(1)

# qkv: [(B, H, W, L, C_inner) *3]
qkv = self.to_qkv(x).chunk(3, dim=-1)
# q: (B, M, H, W, L, C)
q, k, v = map(lambda t: rearrange(t, 'b h w l (m c) -> b m h w l c',
m=self.heads), qkv)

# attention, (B, M, H, W, L, L)
att_map = torch.einsum('b m h w i c, b m h w j c -> b m h w i j',
q, k) * self.scale
# add mask
att_map = att_map.masked_fill(mask == 0, -float('inf'))
# softmax
att_map = self.attend(att_map)

# out:(B, M, H, W, L, C_head)
out = torch.einsum('b m h w i j, b m h w j c -> b m h w i c', att_map,
v)
out = rearrange(out, 'b m h w l c -> b h w l (m c)',
m=self.heads)
out = self.to_out(out)
# (B L H W C)
out = out.permute(0, 3, 1, 2, 4)
return out


class BaseEncoder(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, CavAttention(dim,
heads=heads,
dim_head=dim_head,
dropout=dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
]))

def forward(self, x, mask):
for attn, ff in self.layers:
x = attn(x, mask=mask) + x
x = ff(x) + x
return x


class BaseTransformer(nn.Module):
def __init__(self, args):
super().__init__()

dim = args['dim']
depth = args['depth']
heads = args['heads']
dim_head = args['dim_head']
mlp_dim = args['mlp_dim']
dropout = args['dropout']
max_cav = args['max_cav']

self.encoder = BaseEncoder(dim, depth, heads, dim_head, mlp_dim,
dropout)

def forward(self, x, mask):
# B, L, H, W, C
output = self.encoder(x, mask)
# B, H, W, C
output = output[:, 0]

return output
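
A small usage sketch under assumed hyperparameters. CavAttention attends across the CAV axis L independently at every spatial location, and output[:, 0] keeps the ego vehicle's slice.

# Illustrative hyperparameters (assumptions, not the shipped config).
args = {'dim': 256, 'depth': 3, 'heads': 8, 'dim_head': 32,
        'mlp_dim': 256, 'dropout': 0.1, 'max_cav': 5}
transformer = BaseTransformer(args)

B, L, H, W, C = 2, 5, 48, 176, 256
x = torch.rand(B, L, H, W, C)
mask = torch.ones(B, H, W, L, 1)  # 1 = valid CAV, 0 = padded slot
fused = transformer(x, mask)      # (B, H, W, C): fused ego features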