Fix CoTrans flops and training
LukasHedegaard committed Dec 14, 2021
1 parent 24209db commit d33fbb7
Showing 6 changed files with 65 additions and 148 deletions.
20 changes: 8 additions & 12 deletions main.py
@@ -59,9 +59,10 @@ def main(args):
np.random.seed(seed)
random.seed(seed)

model = transformer_models.VisionTransformer_v3(
# model = transformer_models.VisionTransformer_v3(
model = transformer_models.CoVisionTransformer(
args=args,
img_dim=args.enc_layers, # VisionTransformer_v3
img_dim=args.enc_layers,
patch_dim=args.patch_dim,
out_dim=args.numclass,
embedding_dim=args.embedding_dim,
@@ -78,18 +79,12 @@
try:
from ptflops import get_model_complexity_info

def input_constructor(*largs, **lkwargs):
return {
"sequence_input_rgb": torch.ones(()).new_empty(
(1, args.enc_layers, args.dim_feature // 3 * 2)
),
"sequence_input_flow": torch.ones(()).new_empty(
(1, args.enc_layers, args.dim_feature // 3)
),
}
# Warm up model
model.forward_steps(torch.randn(1, args.dim_feature, args.enc_layers))
model.call_mode = "forward_step"

flops, params = get_model_complexity_info(
model, (0, 0), input_constructor=input_constructor, as_strings=False
model, (args.dim_feature,), as_strings=False
)
print(f"Model FLOPs: {flops}")
print(f"Model params: {params}")
@@ -98,6 +93,7 @@ def input_constructor(*largs, **lkwargs):
print(e)
...

model.call_mode = "forward"
model.to(device)

loss_need = [
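The updated FLOPs measurement first runs `forward_steps` on a full clip to fill the model's internal state buffers, switches `call_mode` to `"forward_step"`, and then lets ptflops profile a single step, so the reported figure is per-frame rather than per-clip. A minimal sketch of that pattern, assuming the Continual Inference `forward_steps`/`call_mode` API used above (the helper name and the `print_per_layer_stat` flag are illustrative):

```python
import torch
from ptflops import get_model_complexity_info


def per_step_complexity(model, dim_feature: int, seq_len: int):
    """Profile the per-frame cost of a continual-inference model."""
    # Warm up: a full (batch, channels, time) clip fills the state buffers.
    model.forward_steps(torch.randn(1, dim_feature, seq_len))

    # Profile one step; ptflops constructs an input of shape (1, dim_feature).
    model.call_mode = "forward_step"
    flops, params = get_model_complexity_info(
        model, (dim_feature,), as_strings=False, print_per_layer_stat=False
    )

    # Switch back to clip-wise inference before training.
    model.call_mode = "forward"
    return flops, params
```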
5 changes: 4 additions & 1 deletion train.py
@@ -76,7 +76,10 @@ def train_one_epoch(
class_h_target = class_h_target.to(device)
dec_target = dec_target.to(device)

enc_score_p0 = model(camera_inputs, motion_inputs)
# enc_score_p0 = model(camera_inputs, motion_inputs)
enc_score_p0 = model(
torch.cat((camera_inputs, motion_inputs), 2).transpose(1, 2)
)

outputs = {
"labels_encoder": enc_score_p0, # [128, 22]
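The training loop now hands the model one tensor with channels before time instead of separate RGB and flow tensors. A shape sketch of that transformation, assuming the 2/3–1/3 feature split between RGB and flow used in main.py (the concrete sizes are illustrative):

```python
import torch

B, T, C = 128, 64, 3072  # batch, clip length, total feature size (illustrative)
camera_inputs = torch.randn(B, T, C // 3 * 2)  # RGB features,  (B, T, 2C/3)
motion_inputs = torch.randn(B, T, C // 3)      # flow features, (B, T, C/3)

# Concatenate along the feature axis and move features before time,
# matching the (B, C, T) layout the continual model is warmed up with in main.py.
clip = torch.cat((camera_inputs, motion_inputs), dim=2).transpose(1, 2)
assert clip.shape == (B, C, T)
```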
36 changes: 1 addition & 35 deletions transformer_models/Transformer.py
@@ -1,8 +1,7 @@
from torch import nn
from .Attention import SelfAttention
import continual as co
from continual_transformers.co_si_trans_enc import CoSiTransformerEncoder
from continual_transformers.co_resi_trans_enc import CoReSiTransformerEncoder
from continual_transformers import CoSiTransformerEncoder, CoReSiTransformerEncoder


class Residual(nn.Module):
@@ -122,39 +121,6 @@ def CoTransformerModel(
dtype=None,
)

# layers = []

# for d in range(depth):
# CoMHA = (
# CoReMultiheadAttention if d == 0 and depth == 2 else CoSiMultiheadAttention
# )
# layers.extend(
# [
# co.Residual(
# co.Sequential(
# co.forward_stepping(nn.LayerNorm(dim)),
# CoMHA(
# dim,
# heads,
# attn_dropout_rate,
# sequence_len=sequence_len,
# forward_returns_attn_mask=False,
# ),
# nn.Dropout(p=dropout_rate),
# )
# ),
# co.Residual(
# co.Sequential(
# co.forward_stepping(nn.LayerNorm(dim)),
# co.forward_stepping(FeedForward(dim, mlp_dim, dropout_rate)),
# )
# ),
# ]
# )

# net = co.Sequential(*layers)
# return net


def _register_ptflops():
try:
57 changes: 8 additions & 49 deletions transformer_models/ViT.py
@@ -2,17 +2,14 @@
import torch
from torch.functional import Tensor
import torch.nn as nn
import torch.nn.functional as F
from .decoder import Decoder, DecoderLayer
from .attn import FullAttention, ProbAttention, AttentionLayer
from .Transformer import TransformerModel, CoTransformerModel
from ipdb import set_trace
from .PositionalEncoding import (
FixedPositionalEncoding,
LearnedPositionalEncoding,
ShiftingLearnedPositionalEncoding,
)
import continual as co
from continual_transformers import CircularPositionalEncoding

__all__ = ["ViT_B16", "ViT_B32", "ViT_L16", "ViT_L32", "ViT_H14"]

@@ -43,19 +40,11 @@ def CoVisionTransformer(
seq_length = num_patches # no class token
flatten_dim = patch_dim * patch_dim * num_channels

linear_encoding = nn.Linear(flatten_dim, embedding_dim)
if positional_encoding_type == "learned":
position_encoding = LearnedPositionalEncoding(
seq_length, embedding_dim, seq_length
)
elif positional_encoding_type == "fixed":
position_encoding = FixedPositionalEncoding(
embedding_dim,
)
if positional_encoding_type == "shifting_learned":
position_encoding = ShiftingLearnedPositionalEncoding(
2 * seq_length, embedding_dim, seq_length
)
linear_encoding = co.Linear(flatten_dim, embedding_dim, channel_dim=1)
assert positional_encoding_type == "shifting_learned"
position_encoding = CircularPositionalEncoding(
embedding_dim, 2 * seq_length, forward_update_index_steps=1
)
print("position encoding :", positional_encoding_type)

pe_dropout = nn.Dropout(p=dropout_rate)
@@ -68,21 +57,10 @@
dropout_rate,
attn_dropout_rate,
)
pre_head_ln = nn.LayerNorm(embedding_dim)
mlp_head = nn.Linear(hidden_dim, out_dim)

def concat_inputs(inputs: Tuple[Tensor, Tensor]) -> Tensor:
sequence_input_rgb, sequence_input_flow = inputs
if with_camera and with_motion:
x = torch.cat((sequence_input_rgb, sequence_input_flow), 2)
elif with_camera:
x = sequence_input_rgb
elif with_motion:
x = sequence_input_flow
return x
pre_head_ln = co.Lambda(nn.LayerNorm(embedding_dim), takes_time=False)
mlp_head = co.Linear(hidden_dim, out_dim, channel_dim=1)

return co.Sequential(
co.Lambda(concat_inputs),
linear_encoding,
position_encoding,
pe_dropout,
@@ -160,7 +138,6 @@ def __init__(
)
self.pre_head_ln = nn.LayerNorm(embedding_dim)

d_model = args.decoder_embedding_dim
use_representation = False # False
if use_representation:
self.mlp_head = nn.Sequential(
@@ -172,24 +149,6 @@
else:
self.mlp_head = nn.Linear(hidden_dim, out_dim)

if self.conv_patch_representation:
self.conv_x = nn.Conv1d(
self.num_channels,
self.embedding_dim,
kernel_size=self.patch_dim,
stride=self.patch_dim,
padding=self._get_padding(
"VALID",
(self.patch_dim),
),
)
else:
self.conv_x = None

self.to_cls_token = nn.Identity()
self.classifier = nn.Linear(d_model, out_dim)
self.after_dropout = nn.Dropout(p=self.dropout_rate)

def forward(self, inputs: Tuple[Tensor, Tensor]):
sequence_input_rgb, sequence_input_flow = inputs
if self.with_camera and self.with_motion:
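After these changes the classification head is one purely continual pipeline. A condensed sketch of the composition built in CoVisionTransformer, assuming the `continual` / `continual_transformers` constructor signatures shown in the diff (the helper function itself is illustrative):

```python
import continual as co
from torch import nn
from continual_transformers import CircularPositionalEncoding


def build_co_pipeline(flatten_dim, embedding_dim, hidden_dim, out_dim,
                      seq_length, dropout_rate, encoder):
    """Sketch of the continual classification head assembled in ViT.py."""
    return co.Sequential(
        # Token embedding; channel_dim=1 because inputs arrive as (B, C, T).
        co.Linear(flatten_dim, embedding_dim, channel_dim=1),
        # Positions recycle modulo 2 * seq_length and advance one index per step.
        CircularPositionalEncoding(
            embedding_dim, 2 * seq_length, forward_update_index_steps=1
        ),
        nn.Dropout(p=dropout_rate),
        encoder,  # e.g. the CoTransformerModel(...) built in Transformer.py
        # LayerNorm acts per token, so it is wrapped without a time axis.
        co.Lambda(nn.LayerNorm(embedding_dim), takes_time=False),
        co.Linear(hidden_dim, out_dim, channel_dim=1),
    )
```

Because every stage is continual, the same object serves clip-wise training (`call_mode = "forward"`) and frame-wise inference (`call_mode = "forward_step"`), which is what main.py relies on for the FLOPs measurement.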
2 changes: 1 addition & 1 deletion transformer_models/__init__.py
@@ -1,4 +1,4 @@
from .ViT import VisionTransformer_v3
from .ViT import VisionTransformer_v3, CoVisionTransformer
from .HybridViT import ResNetHybridViT, AxialNetHybridViT

__all__ = ['ResNetHybridViT', 'AxialNetHybridViT', 'VisionTransformer_v3']
93 changes: 43 additions & 50 deletions util/loss.py
@@ -5,13 +5,14 @@


class SetCriterion(nn.Module):
""" This class computes the loss for DETR.
"""This class computes the loss for DETR.
The process happens in two steps:
1) we compute hungarian assignment between ground truth boxes and the outputs of the model
2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
"""

def __init__(self, num_classes, losses, args):
""" Create the criterion.
"""Create the criterion.
Parameters:
num_classes: number of object categories, omitting the special no-object category
matcher: module able to compute a matching between targets and proposals
@@ -25,11 +26,11 @@ def __init__(self, num_classes, losses, args):
self.classification_h_loss_coef = args.classification_h_loss_coef
self.similar_loss_coef = args.similar_loss_coef
self.weight_dict = {
'labels_encoder': self.classification_h_loss_coef,
'labels_decoder': args.classification_pred_loss_coef,
'labels_x0': self.classification_x_loss_coef,
'labels_xt': self.classification_x_loss_coef,
'distance': self.similar_loss_coef,
"labels_encoder": self.classification_h_loss_coef,
"labels_decoder": args.classification_pred_loss_coef,
"labels_x0": self.classification_x_loss_coef,
"labels_xt": self.classification_x_loss_coef,
"distance": self.similar_loss_coef,
}
self.losses = losses
self.ignore_index = 21
@@ -41,24 +42,20 @@ def loss_labels(self, input, targets, name):
"""Classification loss (NLL)
targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
"""
# assert 'pred_logits' in outputs
# src_logits = outputs['pred_logits']
#
# idx = self._get_src_permutation_idx(indices)
# target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
# target_classes = torch.full(src_logits.shape[:2], self.num_classes,
# dtype=torch.int64, device=src_logits.device)
# target_classes[idx] = target_classes_o

# loss_ce = F.cross_entropy(outputs, targets, ignore_index=21)
target = targets.float()
# logsoftmax = nn.LogSoftmax(dim=1).to(input.device)

if len(input.shape) > 2 and input.shape[-1] == 1:
input = input.squeeze(-1)

if self.ignore_index >= 0:
notice_index = [i for i in range(target.shape[-1]) if i != self.ignore_index]
output = torch.sum(-target[:, notice_index] * self.logsoftmax(input[:, notice_index]), 1)
if output.sum() == 0: # all targets are the ignore class
loss_ce = torch.tensor(0.).to(input.device).type_as(target)
notice_index = [
i for i in range(target.shape[-1]) if i != self.ignore_index
]
output = torch.sum(
-target[:, notice_index] * self.logsoftmax(input[:, notice_index]), 1
)
if output.sum() == 0: # all targets are the ignore class
loss_ce = torch.tensor(0.0).to(input.device).type_as(target)
else:
loss_ce = torch.mean(output[target[:, self.ignore_index] != 1])
else:
Expand All @@ -67,7 +64,7 @@ def loss_labels(self, input, targets, name):
loss_ce = torch.mean(output)
else:
loss_ce = torch.sum(output)
if torch.isnan(loss_ce).sum()>0:
if torch.isnan(loss_ce).sum() > 0:
set_trace()
losses = {name: loss_ce}

@@ -77,24 +74,17 @@ def loss_labels_decoder(self, input, targets, name):
"""Classification loss (NLL)
targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
"""
# assert 'pred_logits' in outputs
# src_logits = outputs['pred_logits']
#
# idx = self._get_src_permutation_idx(indices)
# target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
# target_classes = torch.full(src_logits.shape[:2], self.num_classes,
# dtype=torch.int64, device=src_logits.device)
# target_classes[idx] = target_classes_o

# loss_ce = F.cross_entropy(outputs, targets, ignore_index=21)
target = targets.float()
# logsoftmax = nn.LogSoftmax(dim=1).to(input.device)
ignore_index = 21 # changing -1 to 21 works slightly better
ignore_index = 21
if ignore_index >= 0:
notice_index = [i for i in range(target.shape[-1]) if i != self.ignore_index]
output = torch.sum(-target[:, notice_index] * self.logsoftmax(input[:, notice_index]), 1)
if output.sum() == 0: # all targets are the ignore class
loss_ce = torch.tensor(0.).to(input.device).type_as(target)
notice_index = [
i for i in range(target.shape[-1]) if i != self.ignore_index
]
output = torch.sum(
-target[:, notice_index] * self.logsoftmax(input[:, notice_index]), 1
)
if output.sum() == 0: # all targets are the ignore class
loss_ce = torch.tensor(0.0).to(input.device).type_as(target)
else:
loss_ce = torch.mean(output[target[:, self.ignore_index] != 1])
else:
@@ -103,7 +93,7 @@ def loss_labels_decoder(self, input, targets, name):
loss_ce = torch.mean(output)
else:
loss_ce = torch.sum(output)
if torch.isnan(loss_ce).sum()>0:
if torch.isnan(loss_ce).sum() > 0:
set_trace()
losses = {name: loss_ce}

@@ -116,26 +106,29 @@ def contrastive_loss(self, output, label, name):
"""
output1, output2 = output
euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True)
loss_contrastive = torch.mean((1.-label) * torch.pow(euclidean_distance, 2) +
(label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
if torch.isnan(loss_contrastive).sum()>0:
loss_contrastive = torch.mean(
(1.0 - label) * torch.pow(euclidean_distance, 2)
+ (label)
* torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
)
if torch.isnan(loss_contrastive).sum() > 0:
set_trace()
losses = {name: loss_contrastive.double()}
return losses
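The reformatted expression is the usual margin-based contrastive loss. With $d = \lVert o_1 - o_2 \rVert_2$ the pairwise Euclidean distance, margin $m$ (`self.margin`), and pair label $y$,

$$\mathcal{L}_{\text{contrastive}} = \operatorname{mean}\big[(1 - y)\, d^2 + y \, \max(0,\, m - d)^2\big],$$

so pairs labelled $y = 0$ are pulled together while pairs labelled $y = 1$ are pushed out beyond the margin.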

def get_loss(self, loss, outputs, targets):
loss_map = {
'labels_encoder': self.loss_labels,
'labels_decoder': self.loss_labels_decoder,
'labels_x0': self.loss_labels,
'labels_xt': self.loss_labels,
'distance': self.contrastive_loss,
"labels_encoder": self.loss_labels,
"labels_decoder": self.loss_labels_decoder,
"labels_x0": self.loss_labels,
"labels_xt": self.loss_labels,
"distance": self.contrastive_loss,
}
assert loss in loss_map, f'do you really want to compute {loss} loss?'
assert loss in loss_map, f"do you really want to compute {loss} loss?"
return loss_map[loss](outputs, targets, name=loss)

def forward(self, outputs, targets):
""" This performs the loss computation.
"""This performs the loss computation.
Parameters:
outputs: dict of tensors, see the output specification of the model for the format
targets: list of dicts, such that len(targets) == batch_size.
