diff --git a/v2xvit/models/__init__.py b/v2xvit/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/v2xvit/models/point_pillar_transformer.py b/v2xvit/models/point_pillar_transformer.py
new file mode 100644
index 0000000..c466196
--- /dev/null
+++ b/v2xvit/models/point_pillar_transformer.py
@@ -0,0 +1,122 @@
+import torch
+import torch.nn as nn
+
+from v2xvit.models.sub_modules.pillar_vfe import PillarVFE
+from v2xvit.models.sub_modules.point_pillar_scatter import PointPillarScatter
+from v2xvit.models.sub_modules.base_bev_backbone import BaseBEVBackbone
+from v2xvit.models.sub_modules.fuse_utils import regroup
+from v2xvit.models.sub_modules.downsample_conv import DownsampleConv
+from v2xvit.models.sub_modules.naive_compress import NaiveCompressor
+from v2xvit.models.sub_modules.v2xvit_basic import V2XTransformer
+
+
+class PointPillarTransformer(nn.Module):
+    def __init__(self, args):
+        super(PointPillarTransformer, self).__init__()
+
+        self.max_cav = args['max_cav']
+        # Pillar VFE
+        self.pillar_vfe = PillarVFE(args['pillar_vfe'],
+                                    num_point_features=4,
+                                    voxel_size=args['voxel_size'],
+                                    point_cloud_range=args['lidar_range'])
+        self.scatter = PointPillarScatter(args['point_pillar_scatter'])
+        self.backbone = BaseBEVBackbone(args['base_bev_backbone'], 64)
+        # used to downsample the feature map for efficient computation
+        self.shrink_flag = False
+        if 'shrink_header' in args:
+            self.shrink_flag = True
+            self.shrink_conv = DownsampleConv(args['shrink_header'])
+        self.compression = False
+
+        if args['compression'] > 0:
+            self.compression = True
+            self.naive_compressor = NaiveCompressor(256, args['compression'])
+
+        self.fusion_net = V2XTransformer(args['transformer'])
+
+        self.cls_head = nn.Conv2d(128 * 2, args['anchor_number'],
+                                  kernel_size=1)
+        self.reg_head = nn.Conv2d(128 * 2, 7 * args['anchor_number'],
+                                  kernel_size=1)
+
+        if args['backbone_fix']:
+            self.backbone_fix()
+
+    def backbone_fix(self):
+        """
+        Fix the parameters of the backbone when fine-tuning on time delay.
+        """
+        for p in self.pillar_vfe.parameters():
+            p.requires_grad = False
+
+        for p in self.scatter.parameters():
+            p.requires_grad = False
+
+        for p in self.backbone.parameters():
+            p.requires_grad = False
+
+        if self.compression:
+            for p in self.naive_compressor.parameters():
+                p.requires_grad = False
+        if self.shrink_flag:
+            for p in self.shrink_conv.parameters():
+                p.requires_grad = False
+
+        for p in self.cls_head.parameters():
+            p.requires_grad = False
+        for p in self.reg_head.parameters():
+            p.requires_grad = False
+
+    def forward(self, data_dict):
+        voxel_features = data_dict['processed_lidar']['voxel_features']
+        voxel_coords = data_dict['processed_lidar']['voxel_coords']
+        voxel_num_points = data_dict['processed_lidar']['voxel_num_points']
+        record_len = data_dict['record_len']
+        spatial_correction_matrix = data_dict['spatial_correction_matrix']
+
+        # B, max_cav, 3(dt dv infra), 1, 1
+        prior_encoding =\
+            data_dict['prior_encoding'].unsqueeze(-1).unsqueeze(-1)
+
+        batch_dict = {'voxel_features': voxel_features,
+                      'voxel_coords': voxel_coords,
+                      'voxel_num_points': voxel_num_points,
+                      'record_len': record_len}
+        # n, 4 -> n, c
+        batch_dict = self.pillar_vfe(batch_dict)
+        # n, c -> N, C, H, W
+        batch_dict = self.scatter(batch_dict)
+        batch_dict = self.backbone(batch_dict)
+
+        spatial_features_2d = batch_dict['spatial_features_2d']
+        # downsample feature to reduce memory
+        if self.shrink_flag:
+            spatial_features_2d = self.shrink_conv(spatial_features_2d)
+        # compressor
+        if self.compression:
+            spatial_features_2d = self.naive_compressor(spatial_features_2d)
+        # N, C, H, W -> B, L, C, H, W
+        regroup_feature, mask = regroup(spatial_features_2d,
+                                        record_len,
+                                        self.max_cav)
+        # prior encoding added
+        prior_encoding = prior_encoding.repeat(1, 1, 1,
+                                               regroup_feature.shape[3],
+                                               regroup_feature.shape[4])
+        regroup_feature = torch.cat([regroup_feature, prior_encoding], dim=2)
+
+        # b l c h w -> b l h w c
+        regroup_feature = regroup_feature.permute(0, 1, 3, 4, 2)
+        # transformer fusion
+        fused_feature = self.fusion_net(regroup_feature, mask, spatial_correction_matrix)
+        # b h w c -> b c h w
+        fused_feature = fused_feature.permute(0, 3, 1, 2)
+
+        psm = self.cls_head(fused_feature)
+        rm = self.reg_head(fused_feature)
+
+        output_dict = {'psm': psm,
+                       'rm': rm}
+
+        return output_dict
diff --git a/v2xvit/models/sub_modules/__init__.py b/v2xvit/models/sub_modules/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/v2xvit/models/sub_modules/base_bev_backbone.py b/v2xvit/models/sub_modules/base_bev_backbone.py
new file mode 100644
index 0000000..2db16e3
--- /dev/null
+++ b/v2xvit/models/sub_modules/base_bev_backbone.py
@@ -0,0 +1,122 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+class BaseBEVBackbone(nn.Module):
+    def __init__(self, model_cfg, input_channels):
+        super().__init__()
+        self.model_cfg = model_cfg
+
+        if 'layer_nums' in self.model_cfg:
+
+            assert len(self.model_cfg['layer_nums']) == \
+                   len(self.model_cfg['layer_strides']) == \
+                   len(self.model_cfg['num_filters'])
+
+            layer_nums = self.model_cfg['layer_nums']
+            layer_strides = self.model_cfg['layer_strides']
+            num_filters = self.model_cfg['num_filters']
+        else:
+            layer_nums = layer_strides = num_filters = []
+
+        if 'upsample_strides' in self.model_cfg:
+            assert len(self.model_cfg['upsample_strides']) \
+                   == len(self.model_cfg['num_upsample_filter'])
+
+            num_upsample_filters = self.model_cfg['num_upsample_filter']
+            upsample_strides = self.model_cfg['upsample_strides']
+
+        else:
+            upsample_strides = num_upsample_filters = []
+
+        num_levels = len(layer_nums)
+        c_in_list = [input_channels, *num_filters[:-1]]
+
+        self.blocks = nn.ModuleList()
+        self.deblocks = nn.ModuleList()
+
+        for idx in range(num_levels):
+            cur_layers = [
+                nn.ZeroPad2d(1),
+                nn.Conv2d(
+                    c_in_list[idx], num_filters[idx], kernel_size=3,
+                    stride=layer_strides[idx], padding=0, bias=False
+                ),
+                nn.BatchNorm2d(num_filters[idx], eps=1e-3, momentum=0.01),
+                nn.ReLU()
+            ]
+            for k in range(layer_nums[idx]):
+                cur_layers.extend([
+                    nn.Conv2d(num_filters[idx], num_filters[idx],
+                              kernel_size=3, padding=1, bias=False),
+                    nn.BatchNorm2d(num_filters[idx], eps=1e-3, momentum=0.01),
+                    nn.ReLU()
+                ])
+
+            self.blocks.append(nn.Sequential(*cur_layers))
+            if len(upsample_strides) > 0:
+                stride = upsample_strides[idx]
+                if stride >= 1:
+                    self.deblocks.append(nn.Sequential(
+                        nn.ConvTranspose2d(
+                            num_filters[idx], num_upsample_filters[idx],
+                            upsample_strides[idx],
+                            stride=upsample_strides[idx], bias=False
+                        ),
+                        nn.BatchNorm2d(num_upsample_filters[idx],
+                                       eps=1e-3, momentum=0.01),
+                        nn.ReLU()
+                    ))
+                else:
+                    stride = np.round(1 / stride).astype(int)
+                    self.deblocks.append(nn.Sequential(
+                        nn.Conv2d(
+                            num_filters[idx], num_upsample_filters[idx],
+                            stride,
+                            stride=stride, bias=False
+                        ),
+                        nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3,
+                                       momentum=0.01),
+                        nn.ReLU()
+                    ))
+
+        c_in = sum(num_upsample_filters)
+        if len(upsample_strides) > num_levels:
+            self.deblocks.append(nn.Sequential(
+                nn.ConvTranspose2d(c_in, c_in, upsample_strides[-1],
+ stride=upsample_strides[-1], bias=False), + nn.BatchNorm2d(c_in, eps=1e-3, momentum=0.01), + nn.ReLU(), + )) + + self.num_bev_features = c_in + + def forward(self, data_dict): + spatial_features = data_dict['spatial_features'] + + ups = [] + ret_dict = {} + x = spatial_features + + for i in range(len(self.blocks)): + x = self.blocks[i](x) + + stride = int(spatial_features.shape[2] / x.shape[2]) + ret_dict['spatial_features_%dx' % stride] = x + + if len(self.deblocks) > 0: + ups.append(self.deblocks[i](x)) + else: + ups.append(x) + + if len(ups) > 1: + x = torch.cat(ups, dim=1) + elif len(ups) == 1: + x = ups[0] + + if len(self.deblocks) > len(self.blocks): + x = self.deblocks[-1](x) + + data_dict['spatial_features_2d'] = x + return data_dict diff --git a/v2xvit/models/sub_modules/base_transformer.py b/v2xvit/models/sub_modules/base_transformer.py new file mode 100644 index 0000000..32a582a --- /dev/null +++ b/v2xvit/models/sub_modules/base_transformer.py @@ -0,0 +1,124 @@ +import torch +from torch import nn + +from einops import rearrange + + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout=0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + + +class CavAttention(nn.Module): + """ + Vanilla CAV attention. + """ + def __init__(self, dim, heads, dim_head=64, dropout=0.1): + super().__init__() + inner_dim = heads * dim_head + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim=-1) + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x, mask, prior_encoding): + # x: (B, L, H, W, C) -> (B, H, W, L, C) + # mask: (B, L) + x = x.permute(0, 2, 3, 1, 4) + # mask: (B, 1, H, W, L, 1) + mask = mask.unsqueeze(1) + + # qkv: [(B, H, W, L, C_inner) *3] + qkv = self.to_qkv(x).chunk(3, dim=-1) + # q: (B, M, H, W, L, C) + q, k, v = map(lambda t: rearrange(t, 'b h w l (m c) -> b m h w l c', + m=self.heads), qkv) + + # attention, (B, M, H, W, L, L) + att_map = torch.einsum('b m h w i c, b m h w j c -> b m h w i j', + q, k) * self.scale + # add mask + att_map = att_map.masked_fill(mask == 0, -float('inf')) + # softmax + att_map = self.attend(att_map) + + # out:(B, M, H, W, L, C_head) + out = torch.einsum('b m h w i j, b m h w j c -> b m h w i c', att_map, + v) + out = rearrange(out, 'b m h w l c -> b h w l (m c)', + m=self.heads) + out = self.to_out(out) + # (B L H W C) + out = out.permute(0, 3, 1, 2, 4) + return out + + +class BaseEncoder(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, CavAttention(dim, + heads=heads, + dim_head=dim_head, + dropout=dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)) + ])) + + def forward(self, x, mask): + for attn, ff in self.layers: + x = attn(x, mask=mask) + x + x = ff(x) + x + return x + + +class BaseTransformer(nn.Module): + def __init__(self, args): + super().__init__() + + dim = args['dim'] + depth = args['depth'] + heads = args['heads'] + 
dim_head = args['dim_head'] + mlp_dim = args['mlp_dim'] + dropout = args['dropout'] + max_cav = args['max_cav'] + + self.encoder = BaseEncoder(dim, depth, heads, dim_head, mlp_dim, + dropout) + + def forward(self, x, mask): + # B, L, H, W, C + output = self.encoder(x, mask) + # B, H, W, C + output = output[:, 0] + + return output \ No newline at end of file diff --git a/v2xvit/models/sub_modules/downsample_conv.py b/v2xvit/models/sub_modules/downsample_conv.py new file mode 100644 index 0000000..a03627c --- /dev/null +++ b/v2xvit/models/sub_modules/downsample_conv.py @@ -0,0 +1,52 @@ +""" +Class used to downsample features by 3*3 conv +""" + +import torch +import torch.nn as nn + + +class DoubleConv(nn.Module): + """ + Double convoltuion + Args: + in_channels: input channel num + out_channels: output channel num + """ + + def __init__(self, in_channels, out_channels, kernel_size, + stride, padding): + super().__init__() + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, + stride=stride, padding=padding), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.double_conv(x) + + +class DownsampleConv(nn.Module): + def __init__(self, config): + super(DownsampleConv, self).__init__() + self.layers = nn.ModuleList([]) + input_dim = config['input_dim'] + + for (ksize, dim, stride, padding) in zip(config['kernal_size'], + config['dim'], + config['stride'], + config['padding']): + self.layers.append(DoubleConv(input_dim, + dim, + kernel_size=ksize, + stride=stride, + padding=padding)) + input_dim = dim + + def forward(self, x): + for i in range(len(self.layers)): + x = self.layers[i](x) + return x \ No newline at end of file diff --git a/v2xvit/models/sub_modules/fuse_utils.py b/v2xvit/models/sub_modules/fuse_utils.py new file mode 100644 index 0000000..7de9279 --- /dev/null +++ b/v2xvit/models/sub_modules/fuse_utils.py @@ -0,0 +1,61 @@ +import torch +import numpy as np + +from einops import rearrange +from v2xvit.utils.common_utils import torch_tensor_to_numpy + + +def regroup(dense_feature, record_len, max_len): + """ + Regroup the data based on the record_len. + + Parameters + ---------- + dense_feature : torch.Tensor + N, C, H, W + record_len : list + [sample1_len, sample2_len, ...] 
+ max_len : int + Maximum cav number + + Returns + ------- + regroup_feature : torch.Tensor + B, L, C, H, W + """ + cum_sum_len = list(np.cumsum(torch_tensor_to_numpy(record_len))) + split_features = torch.tensor_split(dense_feature, + cum_sum_len[:-1]) + regroup_features = [] + mask = [] + + for split_feature in split_features: + # M, C, H, W + feature_shape = split_feature.shape + + # the maximum M is 5 as most 5 cavs + padding_len = max_len - feature_shape[0] + mask.append([1] * feature_shape[0] + [0] * padding_len) + + padding_tensor = torch.zeros(padding_len, feature_shape[1], + feature_shape[2], feature_shape[3]) + padding_tensor = padding_tensor.to(split_feature.device) + + split_feature = torch.cat([split_feature, padding_tensor], + dim=0) + + # 1, 5C, H, W + split_feature = split_feature.view(-1, + feature_shape[2], + feature_shape[3]).unsqueeze(0) + regroup_features.append(split_feature) + + # B, 5C, H, W + regroup_features = torch.cat(regroup_features, dim=0) + # B, L, C, H, W + regroup_features = rearrange(regroup_features, + 'b (l c) h w -> b l c h w', + l=max_len) + mask = torch.from_numpy(np.array(mask)).to(regroup_features.device) + + return regroup_features, mask diff --git a/v2xvit/models/sub_modules/naive_compress.py b/v2xvit/models/sub_modules/naive_compress.py new file mode 100644 index 0000000..0800d9a --- /dev/null +++ b/v2xvit/models/sub_modules/naive_compress.py @@ -0,0 +1,29 @@ +import torch +import torch.nn as nn + + +class NaiveCompressor(nn.Module): + def __init__(self, input_dim, compress_raito): + super().__init__() + self.encoder = nn.Sequential( + nn.Conv2d(input_dim, input_dim//compress_raito, kernel_size=3, + stride=1, padding=1), + nn.BatchNorm2d(input_dim//compress_raito, eps=1e-3, momentum=0.01), + nn.ReLU() + ) + self.decoder = nn.Sequential( + nn.Conv2d(input_dim//compress_raito, input_dim, kernel_size=3, + stride=1, padding=1), + nn.BatchNorm2d(input_dim, eps=1e-3, momentum=0.01), + nn.ReLU(), + nn.Conv2d(input_dim, input_dim, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(input_dim, eps=1e-3, + momentum=0.01), + nn.ReLU() + ) + + def forward(self, x): + x = self.encoder(x) + x = self.decoder(x) + + return x \ No newline at end of file diff --git a/v2xvit/models/sub_modules/pillar_vfe.py b/v2xvit/models/sub_modules/pillar_vfe.py new file mode 100644 index 0000000..0a7c601 --- /dev/null +++ b/v2xvit/models/sub_modules/pillar_vfe.py @@ -0,0 +1,146 @@ +""" +Pillar VFE, credits to OpenPCDet. 
+""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class PFNLayer(nn.Module): + def __init__(self, + in_channels, + out_channels, + use_norm=True, + last_layer=False): + super().__init__() + + self.last_vfe = last_layer + self.use_norm = use_norm + if not self.last_vfe: + out_channels = out_channels // 2 + + if self.use_norm: + self.linear = nn.Linear(in_channels, out_channels, bias=False) + self.norm = nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01) + else: + self.linear = nn.Linear(in_channels, out_channels, bias=True) + + self.part = 50000 + + def forward(self, inputs): + if inputs.shape[0] > self.part: + # nn.Linear performs randomly when batch size is too large + num_parts = inputs.shape[0] // self.part + part_linear_out = [self.linear( + inputs[num_part * self.part:(num_part + 1) * self.part]) + for num_part in range(num_parts + 1)] + x = torch.cat(part_linear_out, dim=0) + else: + x = self.linear(inputs) + torch.backends.cudnn.enabled = False + x = self.norm(x.permute(0, 2, 1)).permute(0, 2, + 1) if self.use_norm else x + torch.backends.cudnn.enabled = True + x = F.relu(x) + x_max = torch.max(x, dim=1, keepdim=True)[0] + + if self.last_vfe: + return x_max + else: + x_repeat = x_max.repeat(1, inputs.shape[1], 1) + x_concatenated = torch.cat([x, x_repeat], dim=2) + return x_concatenated + + +class PillarVFE(nn.Module): + def __init__(self, model_cfg, num_point_features, voxel_size, + point_cloud_range): + super().__init__() + self.model_cfg = model_cfg + + self.use_norm = self.model_cfg['use_norm'] + self.with_distance = self.model_cfg['with_distance'] + + self.use_absolute_xyz = self.model_cfg['use_absolute_xyz'] + num_point_features += 6 if self.use_absolute_xyz else 3 + if self.with_distance: + num_point_features += 1 + + self.num_filters = self.model_cfg['num_filters'] + assert len(self.num_filters) > 0 + num_filters = [num_point_features] + list(self.num_filters) + + pfn_layers = [] + for i in range(len(num_filters) - 1): + in_filters = num_filters[i] + out_filters = num_filters[i + 1] + pfn_layers.append( + PFNLayer(in_filters, out_filters, self.use_norm, + last_layer=(i >= len(num_filters) - 2)) + ) + self.pfn_layers = nn.ModuleList(pfn_layers) + + self.voxel_x = voxel_size[0] + self.voxel_y = voxel_size[1] + self.voxel_z = voxel_size[2] + self.x_offset = self.voxel_x / 2 + point_cloud_range[0] + self.y_offset = self.voxel_y / 2 + point_cloud_range[1] + self.z_offset = self.voxel_z / 2 + point_cloud_range[2] + + def get_output_feature_dim(self): + return self.num_filters[-1] + + @staticmethod + def get_paddings_indicator(actual_num, max_num, axis=0): + actual_num = torch.unsqueeze(actual_num, axis + 1) + max_num_shape = [1] * len(actual_num.shape) + max_num_shape[axis + 1] = -1 + max_num = torch.arange(max_num, + dtype=torch.int, + device=actual_num.device).view(max_num_shape) + paddings_indicator = actual_num.int() > max_num + return paddings_indicator + + def forward(self, batch_dict): + + voxel_features, voxel_num_points, coords = \ + batch_dict['voxel_features'], batch_dict['voxel_num_points'], \ + batch_dict['voxel_coords'] + points_mean = \ + voxel_features[:, :, :3].sum(dim=1, keepdim=True) / \ + voxel_num_points.type_as(voxel_features).view(-1, 1, 1) + f_cluster = voxel_features[:, :, :3] - points_mean + + f_center = torch.zeros_like(voxel_features[:, :, :3]) + f_center[:, :, 0] = voxel_features[:, :, 0] - ( + coords[:, 3].to(voxel_features.dtype).unsqueeze( + 1) * self.voxel_x + self.x_offset) + f_center[:, :, 1] = voxel_features[:, 
:, 1] - ( + coords[:, 2].to(voxel_features.dtype).unsqueeze( + 1) * self.voxel_y + self.y_offset) + f_center[:, :, 2] = voxel_features[:, :, 2] - ( + coords[:, 1].to(voxel_features.dtype).unsqueeze( + 1) * self.voxel_z + self.z_offset) + + if self.use_absolute_xyz: + features = [voxel_features, f_cluster, f_center] + else: + features = [voxel_features[..., 3:], f_cluster, f_center] + + if self.with_distance: + points_dist = torch.norm(voxel_features[:, :, :3], 2, 2, + keepdim=True) + features.append(points_dist) + features = torch.cat(features, dim=-1) + + voxel_count = features.shape[1] + mask = self.get_paddings_indicator(voxel_num_points, voxel_count, + axis=0) + mask = torch.unsqueeze(mask, -1).type_as(voxel_features) + features *= mask + for pfn in self.pfn_layers: + features = pfn(features) + features = features.squeeze() + batch_dict['pillar_features'] = features + return batch_dict diff --git a/v2xvit/models/sub_modules/point_pillar_scatter.py b/v2xvit/models/sub_modules/point_pillar_scatter.py new file mode 100644 index 0000000..166ae28 --- /dev/null +++ b/v2xvit/models/sub_modules/point_pillar_scatter.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn + + +class PointPillarScatter(nn.Module): + def __init__(self, model_cfg): + super().__init__() + + self.model_cfg = model_cfg + self.num_bev_features = self.model_cfg['num_features'] + self.nx, self.ny, self.nz = model_cfg['grid_size'] + assert self.nz == 1 + + def forward(self, batch_dict): + pillar_features, coords = batch_dict['pillar_features'], batch_dict[ + 'voxel_coords'] + batch_spatial_features = [] + batch_size = coords[:, 0].max().int().item() + 1 + + for batch_idx in range(batch_size): + spatial_feature = torch.zeros( + self.num_bev_features, + self.nz * self.nx * self.ny, + dtype=pillar_features.dtype, + device=pillar_features.device) + + batch_mask = coords[:, 0] == batch_idx + this_coords = coords[batch_mask, :] + + indices = this_coords[:, 1] + \ + this_coords[:, 2] * self.nx + \ + this_coords[:, 3] + indices = indices.type(torch.long) + + pillars = pillar_features[batch_mask, :] + pillars = pillars.t() + spatial_feature[:, indices] = pillars + batch_spatial_features.append(spatial_feature) + + batch_spatial_features = \ + torch.stack(batch_spatial_features, 0) + batch_spatial_features = \ + batch_spatial_features.view(batch_size, self.num_bev_features * + self.nz, self.ny, self.nx) + batch_dict['spatial_features'] = batch_spatial_features + + return batch_dict + diff --git a/v2xvit/models/sub_modules/split_attn.py b/v2xvit/models/sub_modules/split_attn.py new file mode 100644 index 0000000..d064d51 --- /dev/null +++ b/v2xvit/models/sub_modules/split_attn.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class RadixSoftmax(nn.Module): + def __init__(self, radix, cardinality): + super(RadixSoftmax, self).__init__() + self.radix = radix + self.cardinality = cardinality + + def forward(self, x): + # x: (B, L, 1, 1, 3C) + batch = x.size(0) + cav_num = x.size(1) + + if self.radix > 1: + # x: (B, L, 1, 3, C) + x = x.view(batch, + cav_num, + self.cardinality, self.radix, -1) + x = F.softmax(x, dim=3) + # B, 3LC + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x + + +class SplitAttn(nn.Module): + def __init__(self, input_dim): + super(SplitAttn, self).__init__() + self.input_dim = input_dim + + self.fc1 = nn.Linear(input_dim, input_dim, bias=False) + self.bn1 = nn.LayerNorm(input_dim) + self.act1 = nn.ReLU() + self.fc2 = nn.Linear(input_dim, input_dim * 3, 
bias=False) + + self.rsoftmax = RadixSoftmax(3, 1) + + def forward(self, window_list): + # window list: [(B, L, H, W, C) * 3] + assert len(window_list) == 3, 'only 3 windows are supported' + + sw, mw, bw = window_list[0], window_list[1], window_list[2] + B, L = sw.shape[0], sw.shape[1] + + # global average pooling, B, L, H, W, C + x_gap = sw + mw + bw + # B, L, 1, 1, C + x_gap = x_gap.mean((2, 3), keepdim=True) + x_gap = self.act1(self.bn1(self.fc1(x_gap))) + # B, L, 1, 1, 3C + x_attn = self.fc2(x_gap) + # B L 1 1 3C + x_attn = self.rsoftmax(x_attn).view(B, L, 1, 1, -1) + + out = sw * x_attn[:, :, :, :, 0:self.input_dim] + \ + mw * x_attn[:, :, :, :, self.input_dim:2*self.input_dim] +\ + bw * x_attn[:, :, :, :, self.input_dim*2:] + + return out diff --git a/v2xvit/models/sub_modules/torch_transformation_utils.py b/v2xvit/models/sub_modules/torch_transformation_utils.py new file mode 100644 index 0000000..3b39fbe --- /dev/null +++ b/v2xvit/models/sub_modules/torch_transformation_utils.py @@ -0,0 +1,426 @@ +""" +torch_transformation_utils.py +""" +import os + +import torch +import torch.nn.functional as F +import numpy as np +import matplotlib.pyplot as plt + + +def get_roi_and_cav_mask(shape, cav_mask, spatial_correction_matrix, + discrete_ratio, downsample_rate): + """ + Get mask for the combination of cav_mask and rorated ROI mask. + Parameters + ---------- + shape : tuple + Shape of (B, L, H, W, C). + cav_mask : torch.Tensor + Shape of (B, L). + spatial_correction_matrix : torch.Tensor + Shape of (B, L, 4, 4) + discrete_ratio : float + Discrete ratio. + downsample_rate : float + Downsample rate. + + Returns + ------- + com_mask : torch.Tensor + Combined mask with shape (B, H, W, L, 1). + + """ + B, L, H, W, C = shape + C = 1 + # (B,L,4,4) + dist_correction_matrix = get_discretized_transformation_matrix( + spatial_correction_matrix, discrete_ratio, + downsample_rate) + # (B*L,2,3) + T = get_transformation_matrix( + dist_correction_matrix.reshape(-1, 2, 3), (H, W)) + # (B,L,1,H,W) + roi_mask = get_rotated_roi((B, L, C, H, W), T) + # (B,L,1,H,W) + com_mask = combine_roi_and_cav_mask(roi_mask, cav_mask) + # (B,H,W,1,L) + com_mask = com_mask.permute(0, 3, 4, 2, 1) + return com_mask + + +def combine_roi_and_cav_mask(roi_mask, cav_mask): + """ + Combine ROI mask and CAV mask + + Parameters + ---------- + roi_mask : torch.Tensor + Mask for ROI region after considering the spatial transformation/correction. + cav_mask : torch.Tensor + Mask for CAV to remove padded 0. + + Returns + ------- + com_mask : torch.Tensor + Combined mask. + """ + # (B, L, 1, 1, 1) + cav_mask = cav_mask.unsqueeze(2).unsqueeze(3).unsqueeze(4) + # (B, L, C, H, W) + cav_mask = cav_mask.expand(roi_mask.shape) + # (B, L, C, H, W) + com_mask = roi_mask * cav_mask + return com_mask + + +def get_rotated_roi(shape, correction_matrix): + """ + Get rorated ROI mask. + + Parameters + ---------- + shape : tuple + Shape of (B,L,C,H,W). + correction_matrix : torch.Tensor + Correction matrix with shape (N,2,3). + + Returns + ------- + roi_mask : torch.Tensor + Roated ROI mask with shape (N,2,3). + + """ + B, L, C, H, W = shape + # To reduce the computation, we only need to calculate the mask for the first channel. 
+ # (B,L,1,H,W) + x = torch.ones((B, L, 1, H, W)).to(correction_matrix.dtype).to( + correction_matrix.device) + # (B*L,1,H,W) + roi_mask = warp_affine(x.reshape(-1, 1, H, W), correction_matrix, + dsize=(H, W), mode="nearest") + # (B,L,C,H,W) + roi_mask = torch.repeat_interleave(roi_mask, C, dim=1).reshape(B, L, C, H, + W) + return roi_mask + + +def get_discretized_transformation_matrix(matrix, discrete_ratio, + downsample_rate): + """ + Get disretized transformation matrix. + Parameters + ---------- + matrix : torch.Tensor + Shape -- (B, L, 4, 4) where B is the batch size, L is the max cav + number. + discrete_ratio : float + Discrete ratio. + downsample_rate : float/int + downsample_rate + + Returns + ------- + matrix : torch.Tensor + Output transformation matrix in 2D with shape (B, L, 2, 3), + including 2D transformation and 2D rotation. + + """ + matrix = matrix[:, :, [0, 1], :][:, :, :, [0, 1, 3]] + # normalize the x,y transformation + matrix[:, :, :, -1] = matrix[:, :, :, -1] \ + / (discrete_ratio * downsample_rate) + + return matrix.type(dtype=torch.float) + + +def _torch_inverse_cast(input): + r""" + Helper function to make torch.inverse work with other than fp32/64. + The function torch.inverse is only implemented for fp32/64 which makes + impossible to be used by fp16 or others. What this function does, + is cast input data type to fp32, apply torch.inverse, + and cast back to the input dtype. + Args: + input : torch.Tensor + Tensor to be inversed. + + Returns: + out : torch.Tensor + Inversed Tensor. + + """ + dtype = input.dtype + if dtype not in (torch.float32, torch.float64): + dtype = torch.float32 + out = torch.inverse(input.to(dtype)).to(input.dtype) + return out + + +def normal_transform_pixel( + height, width, device, dtype, eps=1e-14): + r""" + Compute the normalization matrix from image size in pixels to [-1, 1]. + Args: + height : int + Image height. + width : int + Image width. + device : torch.device + Output tensor devices. + dtype : torch.dtype + Output tensor data type. + eps : float + Epsilon to prevent divide-by-zero errors. + + Returns: + tr_mat : torch.Tensor + Normalized transform with shape :math:`(1, 3, 3)`. + """ + tr_mat = torch.tensor( + [[1.0, 0.0, -1.0], [0.0, 1.0, -1.0], [0.0, 0.0, 1.0]], device=device, + dtype=dtype) # 3x3 + + # prevent divide by zero bugs + width_denom = eps if width == 1 else width - 1.0 + height_denom = eps if height == 1 else height - 1.0 + + tr_mat[0, 0] = tr_mat[0, 0] * 2.0 / width_denom + tr_mat[1, 1] = tr_mat[1, 1] * 2.0 / height_denom + + return tr_mat.unsqueeze(0) # 1x3x3 + + +def eye_like(n, B, device, dtype): + r""" + Return a 2-D tensor with ones on the diagonal and + zeros elsewhere with the same batch size as the input. + Args: + n : int + The number of rows :math:`(n)`. + B : int + Btach size. + device : torch.device + Devices of the output tensor. + dtype : torch.dtype + Data type of the output tensor. + + Returns: + The identity matrix with the shape :math:`(B, n, n)`. + """ + + identity = torch.eye(n, device=device, dtype=dtype) + return identity[None].repeat(B, 1, 1) + + +def normalize_homography(dst_pix_trans_src_pix, dsize_src, dsize_dst=None): + r""" + Normalize a given homography in pixels to [-1, 1]. + Args: + dst_pix_trans_src_pix : torch.Tensor + Homography/ies from source to destination to be normalized with + shape :math:`(B, 3, 3)`. + dsize_src : Tuple[int, int] + Size of the source image (height, width). + dsize_dst : Tuple[int, int] + Size of the destination image (height, width). 
+ + Returns: + dst_norm_trans_src_norm : torch.Tensor + The normalized homography of shape :math:`(B, 3, 3)`. + """ + if dsize_dst is None: + dsize_dst = dsize_src + # source and destination sizes + src_h, src_w = dsize_src + dst_h, dst_w = dsize_dst + device = dst_pix_trans_src_pix.device + dtype = dst_pix_trans_src_pix.dtype + # compute the transformation pixel/norm for src/dst + src_norm_trans_src_pix = normal_transform_pixel(src_h, src_w, device, + dtype).to( + dst_pix_trans_src_pix) + + src_pix_trans_src_norm = _torch_inverse_cast(src_norm_trans_src_pix) + dst_norm_trans_dst_pix = normal_transform_pixel(dst_h, dst_w, device, + dtype).to( + dst_pix_trans_src_pix) + # compute chain transformations + dst_norm_trans_src_norm: torch.Tensor = dst_norm_trans_dst_pix @ ( + dst_pix_trans_src_pix @ src_pix_trans_src_norm) + return dst_norm_trans_src_norm + + +def get_rotation_matrix2d(M, dsize): + r""" + Return rotation matrix for torch.affine_grid based on transformation matrix. + Args: + M : torch.Tensor + Transformation matrix with shape :math:`(B, 2, 3)`. + dsize : Tuple[int, int] + Size of the source image (height, width). + + Returns: + R : torch.Tensor + Rotation matrix with shape :math:`(B, 2, 3)`. + """ + H, W = dsize + B = M.shape[0] + center = torch.Tensor([W / 2, H / 2]).to(M.dtype).to(M.device).unsqueeze(0) + shift_m = eye_like(3, B, M.device, M.dtype) + shift_m[:, :2, 2] = center + + shift_m_inv = eye_like(3, B, M.device, M.dtype) + shift_m_inv[:, :2, 2] = -center + + rotat_m = eye_like(3, B, M.device, M.dtype) + rotat_m[:, :2, :2] = M[:, :2, :2] + affine_m = shift_m @ rotat_m @ shift_m_inv + return affine_m[:, :2, :] # Bx2x3 + + +def get_transformation_matrix(M, dsize): + r""" + Return transformation matrix for torch.affine_grid. + Args: + M : torch.Tensor + Transformation matrix with shape :math:`(N, 2, 3)`. + dsize : Tuple[int, int] + Size of the source image (height, width). + + Returns: + T : torch.Tensor + Transformation matrix with shape :math:`(N, 2, 3)`. + """ + T = get_rotation_matrix2d(M, dsize) + T[..., 2] += M[..., 2] + return T + + +def convert_affinematrix_to_homography(A): + r""" + Convert to homography coordinates + Args: + A : torch.Tensor + The affine matrix with shape :math:`(B,2,3)`. + + Returns: + H : torch.Tensor + The homography matrix with shape of :math:`(B,3,3)`. + """ + H: torch.Tensor = torch.nn.functional.pad(A, [0, 0, 0, 1], "constant", + value=0.0) + H[..., -1, -1] += 1.0 + return H + + +def warp_affine( + src, M, dsize, + mode='bilinear', + padding_mode='zeros', + align_corners=True): + r""" + Transform the src based on transformation matrix M. + Args: + src : torch.Tensor + Input feature map with shape :math:`(B,C,H,W)`. + M : torch.Tensor + Transformation matrix with shape :math:`(B,2,3)`. + dsize : tuple + Tuple of output image H_out and W_out. + mode : str + Interpolation methods for F.grid_sample. + padding_mode : str + Padding methods for F.grid_sample. + align_corners : boolean + Parameter of F.affine_grid. + + Returns: + Transformed features with shape :math:`(B,C,H,W)`. 
+ """ + + B, C, H, W = src.size() + + # we generate a 3x3 transformation matrix from 2x3 affine + M_3x3 = convert_affinematrix_to_homography(M) + dst_norm_trans_src_norm = normalize_homography(M_3x3, (H, W), dsize) + + # src_norm_trans_dst_norm = torch.inverse(dst_norm_trans_src_norm) + src_norm_trans_dst_norm = _torch_inverse_cast(dst_norm_trans_src_norm) + grid = F.affine_grid(src_norm_trans_dst_norm[:, :2, :], + [B, C, dsize[0], dsize[1]], + align_corners=align_corners) + + return F.grid_sample(src.half() if grid.dtype == torch.half else src, grid, + align_corners=align_corners, mode=mode, + padding_mode=padding_mode) + + +class Test: + """ + Test the transformation in this file. + The methods in this class are not supposed to be used outside of this file. + """ + + def __init__(self): + pass + + @staticmethod + def load_img(): + torch.manual_seed(0) + x = torch.randn(1, 5, 16, 400, 200) * 100 + # x = torch.ones(1, 5, 16, 400, 200) + return x + + @staticmethod + def load_raw_transformation_matrix(N): + a = 90 / 180 * np.pi + matrix = torch.Tensor([[np.cos(a), -np.sin(a), 10], + [np.sin(a), np.cos(a), 10]]) + matrix = torch.repeat_interleave(matrix.unsqueeze(0).unsqueeze(0), N, + dim=1) + return matrix + + @staticmethod + def load_raw_transformation_matrix2(N, alpha): + a = alpha / 180 * np.pi + matrix = torch.Tensor([[np.cos(a), -np.sin(a), 0, 0], + [np.sin(a), np.cos(a), 0, 0]]) + matrix = torch.repeat_interleave(matrix.unsqueeze(0).unsqueeze(0), N, + dim=1) + return matrix + + @staticmethod + def test(): + img = Test.load_img() + B, L, C, H, W = img.shape + raw_T = Test.load_raw_transformation_matrix(5) + T = get_transformation_matrix(raw_T.reshape(-1, 2, 3), (H, W)) + img_rot = warp_affine(img.reshape(-1, C, H, W), T, (H, W)) + print(img_rot[0, 0, :, :]) + plt.matshow(img_rot[0, 0, :, :]) + plt.show() + + @staticmethod + def test_combine_roi_and_cav_mask(): + B = 2 + L = 5 + C = 16 + H = 300 + W = 400 + # 2, 5 + cav_mask = torch.Tensor([[1, 1, 1, 0, 0], [1, 0, 0, 0, 0]]) + x = torch.zeros(B, L, C, H, W) + correction_matrix = Test.load_raw_transformation_matrix2(5, 10) + correction_matrix = torch.cat([correction_matrix, correction_matrix], + dim=0) + mask = get_roi_and_cav_mask((B, L, H, W, C), cav_mask, + correction_matrix, 0.4, 4) + plt.matshow(mask[0, :, :, 0, 0]) + plt.show() + + +if __name__ == "__main__": + os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' + Test.test_combine_roi_and_cav_mask() diff --git a/v2xvit/models/sub_modules/v2xvit_basic.py b/v2xvit/models/sub_modules/v2xvit_basic.py new file mode 100644 index 0000000..9934190 --- /dev/null +++ b/v2xvit/models/sub_modules/v2xvit_basic.py @@ -0,0 +1,191 @@ +import math + +from v2xvit.models.sub_modules.base_transformer import * +from v2xvit.models.sub_modules.hmsa import * +from v2xvit.models.sub_modules.mswin import * +from v2xvit.models.sub_modules.torch_transformation_utils import \ + get_transformation_matrix, warp_affine, get_roi_and_cav_mask, \ + get_discretized_transformation_matrix + + +class STTF(nn.Module): + def __init__(self, args): + super(STTF, self).__init__() + self.discrete_ratio = args['voxel_size'][0] + self.downsample_rate = args['downsample_rate'] + + def forward(self, x, mask, spatial_correction_matrix): + x = x.permute(0, 1, 4, 2, 3) + dist_correction_matrix = get_discretized_transformation_matrix( + spatial_correction_matrix, self.discrete_ratio, + self.downsample_rate) + # Only compensate non-ego vehicles + B, L, C, H, W = x.shape + + T = get_transformation_matrix( + dist_correction_matrix[:, 1:, :, 
:].reshape(-1, 2, 3), (H, W)) + cav_features = warp_affine(x[:, 1:, :, :, :].reshape(-1, C, H, W), T, + (H, W)) + cav_features = cav_features.reshape(B, -1, C, H, W) + x = torch.cat([x[:, 0, :, :, :].unsqueeze(1), cav_features], dim=1) + x = x.permute(0, 1, 3, 4, 2) + return x + + +class RelTemporalEncoding(nn.Module): + """ + Implement the Temporal Encoding (Sinusoid) function. + """ + + def __init__(self, n_hid, RTE_ratio, max_len=100, dropout=0.2): + super(RelTemporalEncoding, self).__init__() + position = torch.arange(0., max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, n_hid, 2) * + -(math.log(10000.0) / n_hid)) + emb = nn.Embedding(max_len, n_hid) + emb.weight.data[:, 0::2] = torch.sin(position * div_term) / math.sqrt( + n_hid) + emb.weight.data[:, 1::2] = torch.cos(position * div_term) / math.sqrt( + n_hid) + emb.requires_grad = False + self.RTE_ratio = RTE_ratio + self.emb = emb + self.lin = nn.Linear(n_hid, n_hid) + + def forward(self, x, t): + # When t has unit of 50ms, rte_ratio=1. + # So we can train on 100ms but test on 50ms + return x + self.lin(self.emb(t * self.RTE_ratio)).unsqueeze( + 0).unsqueeze(1) + + +class RTE(nn.Module): + def __init__(self, dim, RTE_ratio=2): + super(RTE, self).__init__() + self.RTE_ratio = RTE_ratio + + self.emb = RelTemporalEncoding(dim, RTE_ratio=self.RTE_ratio) + + def forward(self, x, dts): + # x: (B,L,H,W,C) + # dts: (B,L) + rte_batch = [] + for b in range(x.shape[0]): + rte_list = [] + for i in range(x.shape[1]): + rte_list.append( + self.emb(x[b, i, :, :, :], dts[b, i]).unsqueeze(0)) + rte_batch.append(torch.cat(rte_list, dim=0).unsqueeze(0)) + return torch.cat(rte_batch, dim=0) + + +class V2XFusionBlock(nn.Module): + def __init__(self, num_blocks, cav_att_config, pwindow_config): + super().__init__() + # first multi-agent attention and then multi-window attention + self.layers = nn.ModuleList([]) + self.num_blocks = num_blocks + + for _ in range(num_blocks): + att = HGTCavAttention(cav_att_config['dim'], + heads=cav_att_config['heads'], + dim_head=cav_att_config['dim_head'], + dropout=cav_att_config['dropout']) if \ + cav_att_config['use_hetero'] else \ + CavAttention(cav_att_config['dim'], + heads=cav_att_config['heads'], + dim_head=cav_att_config['dim_head'], + dropout=cav_att_config['dropout']) + self.layers.append(nn.ModuleList([ + PreNorm(cav_att_config['dim'], att), + PreNorm(cav_att_config['dim'], + PyramidWindowAttention(pwindow_config['dim'], + heads=pwindow_config['heads'], + dim_heads=pwindow_config[ + 'dim_head'], + drop_out=pwindow_config[ + 'dropout'], + window_size=pwindow_config[ + 'window_size'], + relative_pos_embedding= + pwindow_config[ + 'relative_pos_embedding'], + fuse_method=pwindow_config[ + 'fusion_method']))])) + + def forward(self, x, mask, prior_encoding): + for cav_attn, pwindow_attn in self.layers: + x = cav_attn(x, mask=mask, prior_encoding=prior_encoding) + x + x = pwindow_attn(x) + x + return x + + +class V2XTEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + cav_att_config = args['cav_att_config'] + pwindow_att_config = args['pwindow_att_config'] + feed_config = args['feed_forward'] + + num_blocks = args['num_blocks'] + depth = args['depth'] + mlp_dim = feed_config['mlp_dim'] + dropout = feed_config['dropout'] + + self.downsample_rate = args['sttf']['downsample_rate'] + self.discrete_ratio = args['sttf']['voxel_size'][0] + self.use_roi_mask = args['use_roi_mask'] + self.use_RTE = cav_att_config['use_RTE'] + self.RTE_ratio = cav_att_config['RTE_ratio'] + self.sttf = 
STTF(args['sttf']) + # adjust the channel numbers from 256+3 -> 256 + self.prior_feed = nn.Linear(cav_att_config['dim'] + 3, + cav_att_config['dim']) + self.layers = nn.ModuleList([]) + if self.use_RTE: + self.rte = RTE(cav_att_config['dim'], self.RTE_ratio) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + V2XFusionBlock(num_blocks, cav_att_config, pwindow_att_config), + PreNorm(cav_att_config['dim'], + FeedForward(cav_att_config['dim'], mlp_dim, + dropout=dropout)) + ])) + + def forward(self, x, mask, spatial_correction_matrix): + + # transform the features to the current timestamp + # velocity, time_delay, infra + # (B,L,H,W,3) + prior_encoding = x[..., -3:] + # (B,L,H,W,C) + x = x[..., :-3] + if self.use_RTE: + # dt: (B,L) + dt = prior_encoding[:, :, 0, 0, 1].to(torch.int) + x = self.rte(x, dt) + x = self.sttf(x, mask, spatial_correction_matrix) + com_mask = mask.unsqueeze(1).unsqueeze(2).unsqueeze( + 3) if not self.use_roi_mask else get_roi_and_cav_mask(x.shape, + mask, + spatial_correction_matrix, + self.discrete_ratio, + self.downsample_rate) + for attn, ff in self.layers: + x = attn(x, mask=com_mask, prior_encoding=prior_encoding) + x = ff(x) + x + return x + + +class V2XTransformer(nn.Module): + def __init__(self, args): + super(V2XTransformer, self).__init__() + + encoder_args = args['encoder'] + self.encoder = V2XTEncoder(encoder_args) + + def forward(self, x, mask, spatial_correction_matrix): + output = self.encoder(x, mask, spatial_correction_matrix) + output = output[:, 0] + return output diff --git a/v2xvit/tools/__init__.py b/v2xvit/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/v2xvit/tools/debug_utils.py b/v2xvit/tools/debug_utils.py new file mode 100644 index 0000000..8ba4e6e --- /dev/null +++ b/v2xvit/tools/debug_utils.py @@ -0,0 +1,66 @@ +import argparse + +import torch +from torch.utils.data import DataLoader + +import v2xvit.hypes_yaml.yaml_utils as yaml_utils +from v2xvit.tools import train_utils +from v2xvit.data_utils.datasets import build_dataset +from v2xvit.visualization import vis_utils + + +def test_parser(): + parser = argparse.ArgumentParser(description="synthetic data generation") + parser.add_argument('--model_dir', type=str, required=True, + help='Continued training path') + parser.add_argument('--fusion_method', type=str, default='late', + help='late, early or intermediate') + opt = parser.parse_args() + return opt + + +def test_bev_post_processing(): + opt = test_parser() + assert opt.fusion_method in ['late', 'early', 'intermediate'] + + hypes = yaml_utils.load_yaml(None, opt) + + print('Dataset Building') + opencood_dataset = build_dataset(hypes, visualize=True, train=False) + data_loader = DataLoader(opencood_dataset, + batch_size=1, + num_workers=0, + collate_fn=opencood_dataset.collate_batch_test, + shuffle=False, + pin_memory=False, + drop_last=False) + + print('Creating Model') + model = train_utils.create_model(hypes) + # we assume gpu is necessary + if torch.cuda.is_available(): + model.cuda() + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + print('Loading Model from checkpoint') + saved_path = opt.model_dir + _, model = train_utils.load_saved_model(saved_path, model) + model.eval() + for i, batch_data in enumerate(data_loader): + batch_data = train_utils.to_device(batch_data, device) + label_map = batch_data["ego"]["label_dict"]["label_map"] + output_dict = { + "cls": label_map[:, 0, :, :], + "reg": label_map[:, 1:, :, :] + } + gt_box_tensor, _ = 
opencood_dataset.post_processor.post_process_debug( + batch_data["ego"], output_dict) + vis_utils.visualize_single_sample_output_bev(gt_box_tensor, + batch_data['ego'][ + 'origin_lidar'].squeeze( + 0), + opencood_dataset) + + +if __name__ == '__main__': + test_bev_post_processing() diff --git a/v2xvit/tools/inference.py b/v2xvit/tools/inference.py new file mode 100644 index 0000000..cf20ac1 --- /dev/null +++ b/v2xvit/tools/inference.py @@ -0,0 +1,195 @@ +import argparse +import os +import time + +import torch +import open3d as o3d +from torch.utils.data import DataLoader + +import v2xvit.hypes_yaml.yaml_utils as yaml_utils +from v2xvit.tools import train_utils, infrence_utils +from v2xvit.data_utils.datasets import build_dataset +from v2xvit.visualization import vis_utils +from v2xvit.utils import eval_utils + + +def test_parser(): + parser = argparse.ArgumentParser(description="synthetic data generation") + parser.add_argument('--model_dir', type=str, required=True, + help='Continued training path') + parser.add_argument('--fusion_method', required=True, type=str, + default='late', + help='late, early or intermediate') + parser.add_argument('--show_vis', action='store_true', + help='whether to show image visualization result') + parser.add_argument('--show_sequence', action='store_true', + help='whether to show video visualization result.' + 'it can note be set true with show_vis together ') + parser.add_argument('--save_vis', action='store_true', + help='whether to save visualization result') + parser.add_argument('--save_npy', action='store_true', + help='whether to save prediction and gt result' + 'in npy file') + opt = parser.parse_args() + return opt + + +def main(): + opt = test_parser() + assert opt.fusion_method in ['late', 'early', 'intermediate'] + assert not (opt.show_vis and opt.show_sequence), \ + 'you can only visualize ' \ + 'the results in single ' \ + 'image mode or video mode' + + hypes = yaml_utils.load_yaml(None, opt) + + print('Dataset Building') + opencood_dataset = build_dataset(hypes, visualize=True, train=False) + data_loader = DataLoader(opencood_dataset, + batch_size=1, + num_workers=10, + collate_fn=opencood_dataset.collate_batch_test, + shuffle=False, + pin_memory=False, + drop_last=False) + + print('Creating Model') + model = train_utils.create_model(hypes) + # we assume gpu is necessary + if torch.cuda.is_available(): + model.cuda() + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + print('Loading Model from checkpoint') + saved_path = opt.model_dir + _, model = train_utils.load_saved_model(saved_path, model) + model.eval() + + # Create the dictionary for evaluation + result_stat = {0.3: {'tp': [], 'fp': [], 'gt': 0}, + 0.5: {'tp': [], 'fp': [], 'gt': 0}, + 0.7: {'tp': [], 'fp': [], 'gt': 0}} + + if opt.show_sequence: + vis = o3d.visualization.Visualizer() + vis.create_window() + + vis.get_render_option().background_color = [0.05, 0.05, 0.05] + vis.get_render_option().point_size = 1.0 + vis.get_render_option().show_coordinate_frame = True + + # used to visualize lidar points + vis_pcd = o3d.geometry.PointCloud() + # used to visualize object bounding box, maximum 50 + vis_aabbs_gt = [] + vis_aabbs_pred = [] + for _ in range(50): + vis_aabbs_gt.append(o3d.geometry.LineSet()) + vis_aabbs_pred.append(o3d.geometry.LineSet()) + + for i, batch_data in enumerate(data_loader): + print(i) + with torch.no_grad(): + torch.cuda.synchronize() + batch_data = train_utils.to_device(batch_data, device) + if opt.fusion_method == 'late': + 
pred_box_tensor, pred_score, gt_box_tensor = \ + infrence_utils.inference_late_fusion(batch_data, + model, + opencood_dataset) + elif opt.fusion_method == 'early': + pred_box_tensor, pred_score, gt_box_tensor = \ + infrence_utils.inference_early_fusion(batch_data, + model, + opencood_dataset) + elif opt.fusion_method == 'intermediate': + pred_box_tensor, pred_score, gt_box_tensor = \ + infrence_utils.inference_intermediate_fusion(batch_data, + model, + opencood_dataset) + else: + raise NotImplementedError('Only early, late and intermediate' + 'fusion is supported.') + eval_utils.caluclate_tp_fp(pred_box_tensor, + pred_score, + gt_box_tensor, + result_stat, + 0.3) + eval_utils.caluclate_tp_fp(pred_box_tensor, + pred_score, + gt_box_tensor, + result_stat, + 0.5) + eval_utils.caluclate_tp_fp(pred_box_tensor, + pred_score, + gt_box_tensor, + result_stat, + 0.7) + if opt.save_npy: + npy_save_path = os.path.join(opt.model_dir, 'npy') + if not os.path.exists(npy_save_path): + os.makedirs(npy_save_path) + infrence_utils.save_prediction_gt(pred_box_tensor, + gt_box_tensor, + batch_data['ego'][ + 'origin_lidar'][0], + i, + npy_save_path) + + if opt.show_vis or opt.save_vis: + vis_save_path = '' + if opt.save_vis: + vis_save_path = os.path.join(opt.model_dir, 'vis') + if not os.path.exists(vis_save_path): + os.makedirs(vis_save_path) + vis_save_path = os.path.join(vis_save_path, '%05d.png' % i) + + opencood_dataset.visualize_result(pred_box_tensor, + gt_box_tensor, + batch_data['ego'][ + 'origin_lidar'][0], + opt.show_vis, + vis_save_path, + dataset=opencood_dataset) + + if opt.show_sequence: + pcd, pred_o3d_box, gt_o3d_box = \ + vis_utils.visualize_inference_sample_dataloader( + pred_box_tensor, + gt_box_tensor, + batch_data['ego']['origin_lidar'][0], + vis_pcd, + mode='constant' + ) + if i == 0: + vis.add_geometry(pcd) + vis_utils.linset_assign_list(vis, + vis_aabbs_pred, + pred_o3d_box, + update_mode='add') + + vis_utils.linset_assign_list(vis, + vis_aabbs_gt, + gt_o3d_box, + update_mode='add') + + vis_utils.linset_assign_list(vis, + vis_aabbs_pred, + pred_o3d_box) + vis_utils.linset_assign_list(vis, + vis_aabbs_gt, + gt_o3d_box) + vis.update_geometry(pcd) + vis.poll_events() + vis.update_renderer() + time.sleep(0.001) + + eval_utils.eval_final_results(result_stat, + opt.model_dir) + if opt.show_sequence: + vis.destroy_window() + + +if __name__ == '__main__': + main() diff --git a/v2xvit/tools/infrence_utils.py b/v2xvit/tools/infrence_utils.py new file mode 100644 index 0000000..693eddc --- /dev/null +++ b/v2xvit/tools/infrence_utils.py @@ -0,0 +1,98 @@ +import os +from collections import OrderedDict + +import numpy as np +import torch + +from v2xvit.utils.common_utils import torch_tensor_to_numpy + + +def inference_late_fusion(batch_data, model, dataset): + """ + Model inference for late fusion. + + Parameters + ---------- + batch_data : dict + model : opencood.object + dataset : opencood.LateFusionDataset + + Returns + ------- + pred_box_tensor : torch.Tensor + The tensor of prediction bounding box after NMS. + gt_box_tensor : torch.Tensor + The tensor of gt bounding box. + """ + output_dict = OrderedDict() + + for cav_id, cav_content in batch_data.items(): + output_dict[cav_id] = model(cav_content) + + pred_box_tensor, pred_score, gt_box_tensor = \ + dataset.post_process(batch_data, + output_dict) + + return pred_box_tensor, pred_score, gt_box_tensor + + +def inference_early_fusion(batch_data, model, dataset): + """ + Model inference for early fusion. 
+
+    Parameters
+    ----------
+    batch_data : dict
+    model : opencood.object
+    dataset : opencood.EarlyFusionDataset
+
+    Returns
+    -------
+    pred_box_tensor : torch.Tensor
+        The tensor of prediction bounding box after NMS.
+    gt_box_tensor : torch.Tensor
+        The tensor of gt bounding box.
+    """
+    output_dict = OrderedDict()
+    cav_content = batch_data['ego']
+
+    output_dict['ego'] = model(cav_content)
+
+    pred_box_tensor, pred_score, gt_box_tensor = \
+        dataset.post_process(batch_data,
+                             output_dict)
+
+    return pred_box_tensor, pred_score, gt_box_tensor
+
+
+def inference_intermediate_fusion(batch_data, model, dataset):
+    """
+    Model inference for intermediate fusion (reuses the early fusion pipeline).
+
+    Parameters
+    ----------
+    batch_data : dict
+    model : opencood.object
+    dataset : opencood.IntermediateFusionDataset
+
+    Returns
+    -------
+    pred_box_tensor : torch.Tensor
+        The tensor of prediction bounding box after NMS.
+    gt_box_tensor : torch.Tensor
+        The tensor of gt bounding box.
+    """
+    return inference_early_fusion(batch_data, model, dataset)
+
+
+def save_prediction_gt(pred_tensor, gt_tensor, pcd, timestamp, save_path):
+    """
+    Save prediction and gt tensors to npy files.
+    """
+    pred_np = torch_tensor_to_numpy(pred_tensor)
+    gt_np = torch_tensor_to_numpy(gt_tensor)
+    pcd_np = torch_tensor_to_numpy(pcd)
+
+    np.save(os.path.join(save_path, '%04d_pcd.npy' % timestamp), pcd_np)
+    np.save(os.path.join(save_path, '%04d_pred.npy' % timestamp), pred_np)
+    np.save(os.path.join(save_path, '%04d_gt.npy' % timestamp), gt_np)
diff --git a/v2xvit/tools/train.py b/v2xvit/tools/train.py
new file mode 100644
index 0000000..01d91b7
--- /dev/null
+++ b/v2xvit/tools/train.py
@@ -0,0 +1,154 @@
+import argparse
+import os
+import statistics
+
+import torch
+import tqdm
+from torch.utils.data import DataLoader
+from tensorboardX import SummaryWriter
+
+import v2xvit.hypes_yaml.yaml_utils as yaml_utils
+from v2xvit.tools import train_utils
+from v2xvit.data_utils.datasets import build_dataset
+
+
+def train_parser():
+    parser = argparse.ArgumentParser(description="synthetic data generation")
+    parser.add_argument("--hypes_yaml", type=str, required=True,
+                        help='data generation yaml file needed ')
+    parser.add_argument('--model_dir', default='',
+                        help='Continued training path')
+    parser.add_argument("--half", action='store_true', help="whether to train with half precision")
+    opt = parser.parse_args()
+    return opt
+
+
+def main():
+    opt = train_parser()
+    hypes = yaml_utils.load_yaml(opt.hypes_yaml, opt)
+
+    print('Dataset Building')
+    opencood_train_dataset = build_dataset(hypes, visualize=False, train=True)
+    opencood_validate_dataset = build_dataset(hypes,
+                                              visualize=False,
+                                              train=False)
+
+    train_loader = DataLoader(opencood_train_dataset,
+                              batch_size=hypes['train_params']['batch_size'],
+                              num_workers=8,
+                              collate_fn=opencood_train_dataset.collate_batch_train,
+                              shuffle=True,
+                              pin_memory=False,
+                              drop_last=True)
+    val_loader = DataLoader(opencood_validate_dataset,
+                            batch_size=hypes['train_params']['batch_size'],
+                            num_workers=8,
+                            collate_fn=opencood_train_dataset.collate_batch_train,
+                            shuffle=False,
+                            pin_memory=False,
+                            drop_last=True)
+
+    print('Creating Model')
+    model = train_utils.create_model(hypes)
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    # we assume gpu is necessary
+    if torch.cuda.is_available():
+        model.to(device)
+
+    # define the loss
+    criterion = train_utils.create_loss(hypes)
+
+    # optimizer setup
+    optimizer = train_utils.setup_optimizer(hypes, model)
+    # lr scheduler setup
+    scheduler = train_utils.setup_lr_schedular(hypes, optimizer)
+
+    # if we want to train from last checkpoint.
+    if opt.model_dir:
+        saved_path = opt.model_dir
+        init_epoch, model = train_utils.load_saved_model(saved_path, model)
+
+    else:
+        init_epoch = 0
+        # if we train the model from scratch, we need to create a folder
+        # to save the model,
+        saved_path = train_utils.setup_train(hypes)
+
+    # record training
+    writer = SummaryWriter(saved_path)
+
+    # half precision training
+    if opt.half:
+        scaler = torch.cuda.amp.GradScaler()
+
+    print('Training start')
+    epoches = hypes['train_params']['epoches']
+    # used to help schedule learning rate
+    for epoch in range(init_epoch, max(epoches, init_epoch)):
+        scheduler.step(epoch)
+        for param_group in optimizer.param_groups:
+            print('learning rate %f' % param_group["lr"])
+        pbar2 = tqdm.tqdm(total=len(train_loader), leave=True)
+        for i, batch_data in enumerate(train_loader):
+            # the model will be set to evaluation mode during validation
+            model.train()
+            model.zero_grad()
+            optimizer.zero_grad()
+
+            batch_data = train_utils.to_device(batch_data, device)
+
+            # case1 : late fusion train --> only ego needed
+            # case2 : early fusion train --> all data projected to ego
+            # case3 : intermediate fusion --> ['ego']['processed_lidar']
+            # becomes a list, which contains all data from other cavs
+            # as well
+            if not opt.half:
+                output_dict = model(batch_data['ego'])
+                # first argument is always your output dictionary,
+                # second argument is always your label dictionary.
+                final_loss = criterion(output_dict, batch_data['ego']['label_dict'])
+            else:
+                with torch.cuda.amp.autocast():
+                    output_dict = model(batch_data['ego'])
+                    final_loss = criterion(output_dict, batch_data['ego']['label_dict'])
+
+            criterion.logging(epoch, i, len(train_loader), writer, pbar=pbar2)
+            pbar2.update(1)
+            # back-propagation
+            if not opt.half:
+                final_loss.backward()
+                optimizer.step()
+            else:
+                scaler.scale(final_loss).backward()
+                scaler.step(optimizer)
+                scaler.update()
+        if epoch % hypes['train_params']['eval_freq'] == 0:
+            valid_ave_loss = []
+
+            with torch.no_grad():
+                for i, batch_data in enumerate(val_loader):
+                    model.eval()
+
+                    batch_data = train_utils.to_device(batch_data, device)
+                    output_dict = model(batch_data['ego'])
+
+                    final_loss = criterion(output_dict,
+                                           batch_data['ego']['label_dict'])
+                    valid_ave_loss.append(final_loss.item())
+            valid_ave_loss = statistics.mean(valid_ave_loss)
+            print('At epoch %d, the validation loss is %f' % (epoch,
+                                                              valid_ave_loss))
+
+            writer.add_scalar('Validate_Loss', valid_ave_loss, epoch)
+
+        if epoch % hypes['train_params']['save_freq'] == 0:
+            torch.save(model.state_dict(),
+                       os.path.join(saved_path,
+                                    'net_epoch%d.pth' % (epoch + 1)))
+
+    print('Training Finished, checkpoints saved to %s' % saved_path)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/v2xvit/tools/train_utils.py b/v2xvit/tools/train_utils.py
new file mode 100644
index 0000000..d4fde47
--- /dev/null
+++ b/v2xvit/tools/train_utils.py
@@ -0,0 +1,224 @@
+import glob
+import importlib
+import yaml
+import os
+import re
+from datetime import datetime
+
+import torch
+import torch.optim as optim
+
+
+def load_saved_model(saved_path, model):
+    """
+    Load the saved model if it exists.
+
+    Parameters
+    ----------
+    saved_path : str
+        model saved path
+    model : opencood object
+        The model instance.
+
+    Returns
+    -------
+    model : opencood object
+        The model instance with the pretrained parameters loaded.
+ """ + assert os.path.exists(saved_path), '{} not found'.format(saved_path) + + def findLastCheckpoint(save_dir): + file_list = glob.glob(os.path.join(save_dir, '*epoch*.pth')) + if file_list: + epochs_exist = [] + for file_ in file_list: + result = re.findall(".*epoch(.*).pth.*", file_) + epochs_exist.append(int(result[0])) + initial_epoch_ = max(epochs_exist) + else: + initial_epoch_ = 0 + return initial_epoch_ + + initial_epoch = findLastCheckpoint(saved_path) + if initial_epoch > 0: + print('resuming by loading epoch %d' % initial_epoch) + model.load_state_dict(torch.load( + os.path.join(saved_path, + 'net_epoch%d.pth' % initial_epoch)), strict=False) + + return initial_epoch, model + + +def setup_train(hypes): + """ + Create folder for saved model based on current timestep and model name + + Parameters + ---------- + hypes: dict + Config yaml dictionary for training: + """ + model_name = hypes['name'] + current_time = datetime.now() + + folder_name = current_time.strftime("_%Y_%m_%d_%H_%M_%S") + folder_name = model_name + folder_name + + current_path = os.path.dirname(__file__) + current_path = os.path.join(current_path, '../logs') + + full_path = os.path.join(current_path, folder_name) + + if not os.path.exists(full_path): + os.makedirs(full_path) + # save the yaml file + save_name = os.path.join(full_path, 'config.yaml') + with open(save_name, 'w') as outfile: + yaml.dump(hypes, outfile) + + return full_path + + +def create_model(hypes): + """ + Import the module "models/[model_name].py + + Parameters + __________ + hypes : dict + Dictionary containing parameters. + + Returns + ------- + model : opencood,object + Model object. + """ + backbone_name = hypes['model']['core_method'] + backbone_config = hypes['model']['args'] + + model_filename = "v2xvit.models." + backbone_name + model_lib = importlib.import_module(model_filename) + model = None + target_model_name = backbone_name.replace('_', '') + + for name, cls in model_lib.__dict__.items(): + if name.lower() == target_model_name.lower(): + model = cls + + if model is None: + print('backbone not found in models folder. Please make sure you ' + 'have a python file named %s and has a class ' + 'called %s ignoring upper/lower case' % (model_filename, + target_model_name)) + exit(0) + instance = model(backbone_config) + return instance + + +def create_loss(hypes): + """ + Create the loss function based on the given loss name. + + Parameters + ---------- + hypes : dict + Configuration params for training. + Returns + ------- + criterion : opencood.object + The loss function. + """ + loss_func_name = hypes['loss']['core_method'] + loss_func_config = hypes['loss']['args'] + + loss_filename = "opencood.loss." + loss_func_name + loss_lib = importlib.import_module(loss_filename) + loss_func = None + target_loss_name = loss_func_name.replace('_', '') + + for name, lfunc in loss_lib.__dict__.items(): + if name.lower() == target_loss_name.lower(): + loss_func = lfunc + + if loss_func is None: + print('loss function not found in loss folder. Please make sure you ' + 'have a python file named %s and has a class ' + 'called %s ignoring upper/lower case' % (loss_filename, + target_loss_name)) + exit(0) + + criterion = loss_func(loss_func_config) + return criterion + + +def setup_optimizer(hypes, model): + """ + Create optimizer corresponding to the yaml file + + Parameters + ---------- + hypes : dict + The training configurations. 
+ model : opencood model + The pytorch model + """ + method_dict = hypes['optimizer'] + optimizer_method = getattr(optim, method_dict['core_method'], None) + if not optimizer_method: + raise ValueError('{} is not supported'.format(method_dict['name'])) + if 'args' in method_dict: + return optimizer_method(filter(lambda p: p.requires_grad, + model.parameters()), + lr=method_dict['lr'], + **method_dict['args']) + else: + return optimizer_method(filter(lambda p: p.requires_grad, + model.parameters()), + lr=method_dict['lr']) + + +def setup_lr_schedular(hypes, optimizer): + """ + Set up the learning rate schedular. + + Parameters + ---------- + hypes : dict + The training configurations. + + optimizer : torch.optimizer + """ + lr_schedule_config = hypes['lr_scheduler'] + + if lr_schedule_config['core_method'] == 'step': + from torch.optim.lr_scheduler import StepLR + step_size = lr_schedule_config['step_size'] + gamma = lr_schedule_config['gamma'] + scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma) + + elif lr_schedule_config['core_method'] == 'multistep': + from torch.optim.lr_scheduler import MultiStepLR + milestones = lr_schedule_config['step_size'] + gamma = lr_schedule_config['gamma'] + scheduler = MultiStepLR(optimizer, + milestones=milestones, + gamma=gamma) + + else: + from torch.optim.lr_scheduler import ExponentialLR + gamma = lr_schedule_config['gamma'] + scheduler = ExponentialLR(optimizer, gamma) + + return scheduler + + +def to_device(inputs, device): + if isinstance(inputs, list): + return [to_device(x, device) for x in inputs] + elif isinstance(inputs, dict): + return {k: to_device(v, device) for k, v in inputs.items()} + else: + if isinstance(inputs, int) or isinstance(inputs, float) \ + or isinstance(inputs, str): + return inputs + return inputs.to(device) diff --git a/v2xvit/utils/box_utils.py b/v2xvit/utils/box_utils.py index 763a290..9dfe44d 100644 --- a/v2xvit/utils/box_utils.py +++ b/v2xvit/utils/box_utils.py @@ -339,7 +339,7 @@ def get_mask_for_boxes_within_range_torch(boxes): bbx is within the range and False means the bbx is outside the range. """ - from opencood.data_utils.datasets import GT_RANGE + from v2xvit.data_utils.datasets import GT_RANGE # mask out the gt bounding box out fixed range (-140, -40, -3, 140, 40 1) device = boxes.device