From d33fbb75fe1c38b0fe9de170e29c4a10f027d324 Mon Sep 17 00:00:00 2001
From: LukasHedegaard
Date: Tue, 14 Dec 2021 06:32:40 +0000
Subject: [PATCH] Fix CoTrans flops and training

---
 main.py                           | 20 +++----
 train.py                          |  5 +-
 transformer_models/Transformer.py | 36 +-----------
 transformer_models/ViT.py         | 57 +++----------------
 transformer_models/__init__.py    |  2 +-
 util/loss.py                      | 93 ++++++++++++++-----------------
 6 files changed, 65 insertions(+), 148 deletions(-)

diff --git a/main.py b/main.py
index 86ddf8f..c9aaf96 100644
--- a/main.py
+++ b/main.py
@@ -59,9 +59,10 @@ def main(args):
     np.random.seed(seed)
     random.seed(seed)
 
-    model = transformer_models.VisionTransformer_v3(
+    # model = transformer_models.VisionTransformer_v3(
+    model = transformer_models.CoVisionTransformer(
         args=args,
-        img_dim=args.enc_layers,  # VisionTransformer_v3
+        img_dim=args.enc_layers,
         patch_dim=args.patch_dim,
         out_dim=args.numclass,
         embedding_dim=args.embedding_dim,
@@ -78,18 +79,12 @@ def main(args):
     try:
        from ptflops import get_model_complexity_info
 
-        def input_constructor(*largs, **lkwargs):
-            return {
-                "sequence_input_rgb": torch.ones(()).new_empty(
-                    (1, args.enc_layers, args.dim_feature // 3 * 2)
-                ),
-                "sequence_input_flow": torch.ones(()).new_empty(
-                    (1, args.enc_layers, args.dim_feature // 3)
-                ),
-            }
+        # Warm up model
+        model.forward_steps(torch.randn(1, args.dim_feature, args.enc_layers))
+        model.call_mode = "forward_step"
 
         flops, params = get_model_complexity_info(
-            model, (0, 0), input_constructor=input_constructor, as_strings=False
+            model, (args.dim_feature,), as_strings=False
         )
         print(f"Model FLOPs: {flops}")
         print(f"Model params: {params}")
@@ -98,6 +93,7 @@ def input_constructor(*largs, **lkwargs):
         print(e)
         ...
 
+    model.call_mode = "forward"
     model.to(device)
 
     loss_need = [
diff --git a/train.py b/train.py
index 9bec3e2..eb5638e 100644
--- a/train.py
+++ b/train.py
@@ -76,7 +76,10 @@ def train_one_epoch(
         class_h_target = class_h_target.to(device)
         dec_target = dec_target.to(device)
 
-        enc_score_p0 = model(camera_inputs, motion_inputs)
+        # enc_score_p0 = model(camera_inputs, motion_inputs)
+        enc_score_p0 = model(
+            torch.cat((camera_inputs, motion_inputs), 2).transpose(1, 2)
+        )
 
         outputs = {
             "labels_encoder": enc_score_p0,  # [128, 22]
diff --git a/transformer_models/Transformer.py b/transformer_models/Transformer.py
index d62518b..48f3b5e 100644
--- a/transformer_models/Transformer.py
+++ b/transformer_models/Transformer.py
@@ -1,8 +1,7 @@
 from torch import nn
 from .Attention import SelfAttention
 import continual as co
-from continual_transformers.co_si_trans_enc import CoSiTransformerEncoder
-from continual_transformers.co_resi_trans_enc import CoReSiTransformerEncoder
+from continual_transformers import CoSiTransformerEncoder, CoReSiTransformerEncoder
 
 
 class Residual(nn.Module):
@@ -122,39 +121,6 @@ def CoTransformerModel(
         dtype=None,
     )
-    # layers = []
-
-    # for d in range(depth):
-    #     CoMHA = (
-    #         CoReMultiheadAttention if d == 0 and depth == 2 else CoSiMultiheadAttention
-    #     )
-    #     layers.extend(
-    #         [
-    #             co.Residual(
-    #                 co.Sequential(
-    #                     co.forward_stepping(nn.LayerNorm(dim)),
-    #                     CoMHA(
-    #                         dim,
-    #                         heads,
-    #                         attn_dropout_rate,
-    #                         sequence_len=sequence_len,
-    #                         forward_returns_attn_mask=False,
-    #                     ),
-    #                     nn.Dropout(p=dropout_rate),
-    #                 )
-    #             ),
-    #             co.Residual(
-    #                 co.Sequential(
-    #                     co.forward_stepping(nn.LayerNorm(dim)),
-    #                     co.forward_stepping(FeedForward(dim, mlp_dim, dropout_rate)),
-    #                 )
-    #             ),
-    #         ]
-    #     )
-
-    # net = co.Sequential(*layers)
-    # return net
-
 
 
 def _register_ptflops():
     try:
diff --git a/transformer_models/ViT.py b/transformer_models/ViT.py
index 05c92c8..dfecfd5 100644
--- a/transformer_models/ViT.py
+++ b/transformer_models/ViT.py
@@ -2,17 +2,14 @@
 import torch
 from torch.functional import Tensor
 import torch.nn as nn
-import torch.nn.functional as F
-from .decoder import Decoder, DecoderLayer
-from .attn import FullAttention, ProbAttention, AttentionLayer
 from .Transformer import TransformerModel, CoTransformerModel
-from ipdb import set_trace
 from .PositionalEncoding import (
     FixedPositionalEncoding,
     LearnedPositionalEncoding,
     ShiftingLearnedPositionalEncoding,
 )
 import continual as co
+from continual_transformers import CircularPositionalEncoding
 
 
 __all__ = ["ViT_B16", "ViT_B32", "ViT_L16", "ViT_L32", "ViT_H14"]
@@ -43,19 +40,11 @@ def CoVisionTransformer(
 
     seq_length = num_patches  # no class token
     flatten_dim = patch_dim * patch_dim * num_channels
-    linear_encoding = nn.Linear(flatten_dim, embedding_dim)
-    if positional_encoding_type == "learned":
-        position_encoding = LearnedPositionalEncoding(
-            seq_length, embedding_dim, seq_length
-        )
-    elif positional_encoding_type == "fixed":
-        position_encoding = FixedPositionalEncoding(
-            embedding_dim,
-        )
-    if positional_encoding_type == "shifting_learned":
-        position_encoding = ShiftingLearnedPositionalEncoding(
-            2 * seq_length, embedding_dim, seq_length
-        )
+    linear_encoding = co.Linear(flatten_dim, embedding_dim, channel_dim=1)
+    assert positional_encoding_type == "shifting_learned"
+    position_encoding = CircularPositionalEncoding(
+        embedding_dim, 2 * seq_length, forward_update_index_steps=1
+    )
 
     print("position encoding :", positional_encoding_type)
     pe_dropout = nn.Dropout(p=dropout_rate)
@@ -68,21 +57,10 @@ def CoVisionTransformer(
         dropout_rate,
         attn_dropout_rate,
     )
-    pre_head_ln = nn.LayerNorm(embedding_dim)
-    mlp_head = nn.Linear(hidden_dim, out_dim)
-
-    def concat_inputs(inputs: Tuple[Tensor, Tensor]) -> Tensor:
-        sequence_input_rgb, sequence_input_flow = inputs
-        if with_camera and with_motion:
-            x = torch.cat((sequence_input_rgb, sequence_input_flow), 2)
-        elif with_camera:
-            x = sequence_input_rgb
-        elif with_motion:
-            x = sequence_input_flow
-        return x
+    pre_head_ln = co.Lambda(nn.LayerNorm(embedding_dim), takes_time=False)
+    mlp_head = co.Linear(hidden_dim, out_dim, channel_dim=1)
 
     return co.Sequential(
-        co.Lambda(concat_inputs),
         linear_encoding,
         position_encoding,
         pe_dropout,
@@ -160,7 +138,6 @@ def __init__(
         )
 
         self.pre_head_ln = nn.LayerNorm(embedding_dim)
-        d_model = args.decoder_embedding_dim
         use_representation = False  # False
         if use_representation:
             self.mlp_head = nn.Sequential(
@@ -172,24 +149,6 @@ def __init__(
         else:
             self.mlp_head = nn.Linear(hidden_dim, out_dim)
 
-        if self.conv_patch_representation:
-            self.conv_x = nn.Conv1d(
-                self.num_channels,
-                self.embedding_dim,
-                kernel_size=self.patch_dim,
-                stride=self.patch_dim,
-                padding=self._get_padding(
-                    "VALID",
-                    (self.patch_dim),
-                ),
-            )
-        else:
-            self.conv_x = None
-
-        self.to_cls_token = nn.Identity()
-        self.classifier = nn.Linear(d_model, out_dim)
-        self.after_dropout = nn.Dropout(p=self.dropout_rate)
-
     def forward(self, inputs: Tuple[Tensor, Tensor]):
         sequence_input_rgb, sequence_input_flow = inputs
         if self.with_camera and self.with_motion:
diff --git a/transformer_models/__init__.py b/transformer_models/__init__.py
index 59664ed..3d9b778 100644
--- a/transformer_models/__init__.py
+++ b/transformer_models/__init__.py
@@ -1,4 +1,4 @@
-from .ViT import VisionTransformer_v3
+from .ViT import VisionTransformer_v3, CoVisionTransformer
 from .HybridViT import ResNetHybridViT, AxialNetHybridViT
 
 __all__ = ['ResNetHybridViT', 'AxialNetHybridViT', 'VisionTransformer_v3']
diff --git a/util/loss.py b/util/loss.py
index 1eb0577..13efcf9 100644
--- a/util/loss.py
+++ b/util/loss.py
@@ -5,13 +5,14 @@
 
 
 class SetCriterion(nn.Module):
-    """ This class computes the loss for DETR.
+    """This class computes the loss for DETR.
     The process happens in two steps:
         1) we compute hungarian assignment between ground truth boxes and the outputs of the model
         2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
     """
+
     def __init__(self, num_classes, losses, args):
-        """ Create the criterion.
+        """Create the criterion.
         Parameters:
             num_classes: number of object categories, omitting the special no-object category
             matcher: module able to compute a matching between targets and proposals
@@ -25,11 +26,11 @@ def __init__(self, num_classes, losses, args):
         self.classification_h_loss_coef = args.classification_h_loss_coef
         self.similar_loss_coef = args.similar_loss_coef
         self.weight_dict = {
-            'labels_encoder': self.classification_h_loss_coef,
-            'labels_decoder': args.classification_pred_loss_coef,
-            'labels_x0': self.classification_x_loss_coef,
-            'labels_xt': self.classification_x_loss_coef,
-            'distance': self.similar_loss_coef,
+            "labels_encoder": self.classification_h_loss_coef,
+            "labels_decoder": args.classification_pred_loss_coef,
+            "labels_x0": self.classification_x_loss_coef,
+            "labels_xt": self.classification_x_loss_coef,
+            "distance": self.similar_loss_coef,
         }
         self.losses = losses
         self.ignore_index = 21
@@ -41,24 +42,20 @@ def loss_labels(self, input, targets, name):
         """Classification loss (NLL)
         targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
         """
-        # assert 'pred_logits' in outputs
-        # src_logits = outputs['pred_logits']
-        #
-        # idx = self._get_src_permutation_idx(indices)
-        # target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
-        # target_classes = torch.full(src_logits.shape[:2], self.num_classes,
-        #                             dtype=torch.int64, device=src_logits.device)
-        # target_classes[idx] = target_classes_o
-
-        # loss_ce = F.cross_entropy(outputs, targets, ignore_index=21)
         target = targets.float()
-        # logsoftmax = nn.LogSoftmax(dim=1).to(input.device)
+
+        if len(input.shape) > 2 and input.shape[-1] == 1:
+            input = input.squeeze(-1)
 
         if self.ignore_index >= 0:
-            notice_index = [i for i in range(target.shape[-1]) if i != self.ignore_index]
-            output = torch.sum(-target[:, notice_index] * self.logsoftmax(input[:, notice_index]), 1)
-            if output.sum() == 0: # 全为 ignore 类
-                loss_ce = torch.tensor(0.).to(input.device).type_as(target)
+            notice_index = [
+                i for i in range(target.shape[-1]) if i != self.ignore_index
+            ]
+            output = torch.sum(
+                -target[:, notice_index] * self.logsoftmax(input[:, notice_index]), 1
+            )
+            if output.sum() == 0:  # 全为 ignore 类
+                loss_ce = torch.tensor(0.0).to(input.device).type_as(target)
             else:
                 loss_ce = torch.mean(output[target[:, self.ignore_index] != 1])
         else:
@@ -67,7 +64,7 @@ def loss_labels(self, input, targets, name):
                 loss_ce = torch.mean(output)
             else:
                 loss_ce = torch.sum(output)
-        if torch.isnan(loss_ce).sum()>0:
+        if torch.isnan(loss_ce).sum() > 0:
             set_trace()
 
         losses = {name: loss_ce}
@@ -77,24 +74,17 @@ def loss_labels_decoder(self, input, targets, name):
         """Classification loss (NLL)
         targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
         """
-        # assert 'pred_logits' in outputs
-        # src_logits = outputs['pred_logits']
-        #
-        # idx = self._get_src_permutation_idx(indices)
-        # target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
-        # target_classes = torch.full(src_logits.shape[:2], self.num_classes,
-        #                             dtype=torch.int64, device=src_logits.device)
-        # target_classes[idx] = target_classes_o
-
-        # loss_ce = F.cross_entropy(outputs, targets, ignore_index=21)
         target = targets.float()
-        # logsoftmax = nn.LogSoftmax(dim=1).to(input.device)
-        ignore_index = 21  # -1 改为21 更好一点
+        ignore_index = 21
         if ignore_index >= 0:
-            notice_index = [i for i in range(target.shape[-1]) if i != self.ignore_index]
-            output = torch.sum(-target[:, notice_index] * self.logsoftmax(input[:, notice_index]), 1)
-            if output.sum() == 0: # 全为 ignore 类
-                loss_ce = torch.tensor(0.).to(input.device).type_as(target)
+            notice_index = [
+                i for i in range(target.shape[-1]) if i != self.ignore_index
+            ]
+            output = torch.sum(
+                -target[:, notice_index] * self.logsoftmax(input[:, notice_index]), 1
+            )
+            if output.sum() == 0:  # 全为 ignore 类
+                loss_ce = torch.tensor(0.0).to(input.device).type_as(target)
             else:
                 loss_ce = torch.mean(output[target[:, self.ignore_index] != 1])
         else:
@@ -103,7 +93,7 @@ def loss_labels_decoder(self, input, targets, name):
                 loss_ce = torch.mean(output)
             else:
                 loss_ce = torch.sum(output)
-        if torch.isnan(loss_ce).sum()>0:
+        if torch.isnan(loss_ce).sum() > 0:
             set_trace()
 
         losses = {name: loss_ce}
@@ -116,26 +106,29 @@ def contrastive_loss(self, output, label, name):
         """
         output1, output2 = output
         euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True)
-        loss_contrastive = torch.mean((1.-label) * torch.pow(euclidean_distance, 2) +
-                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
-        if torch.isnan(loss_contrastive).sum()>0:
+        loss_contrastive = torch.mean(
+            (1.0 - label) * torch.pow(euclidean_distance, 2)
+            + (label)
+            * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
+        )
+        if torch.isnan(loss_contrastive).sum() > 0:
             set_trace()
         losses = {name: loss_contrastive.double()}
         return losses
 
     def get_loss(self, loss, outputs, targets):
         loss_map = {
-            'labels_encoder': self.loss_labels,
-            'labels_decoder': self.loss_labels_decoder,
-            'labels_x0': self.loss_labels,
-            'labels_xt': self.loss_labels,
-            'distance': self.contrastive_loss,
+            "labels_encoder": self.loss_labels,
+            "labels_decoder": self.loss_labels_decoder,
+            "labels_x0": self.loss_labels,
+            "labels_xt": self.loss_labels,
+            "distance": self.contrastive_loss,
         }
-        assert loss in loss_map, f'do you really want to compute {loss} loss?'
+        assert loss in loss_map, f"do you really want to compute {loss} loss?"
         return loss_map[loss](outputs, targets, name=loss)
 
     def forward(self, outputs, targets):
-        """ This performs the loss computation.
+        """This performs the loss computation.
         Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.