From 63473e6a50c8a6b810cb404db5dc7c0dcc5525cc Mon Sep 17 00:00:00 2001
From: LukasHedegaard
Date: Thu, 30 Dec 2021 12:37:55 +0000
Subject: [PATCH] Fix cooadtr merge

---
 main.py                           | 23 ++++++++---------
 transformer_models/Transformer.py | 41 +++++++++++++++++++++++++++++++
 2 files changed, 51 insertions(+), 13 deletions(-)

diff --git a/main.py b/main.py
index 5bbc368..ce9f26f 100644
--- a/main.py
+++ b/main.py
@@ -59,9 +59,10 @@ def main(args):
     np.random.seed(seed)
     random.seed(seed)
 
-    model = transformer_models.VisionTransformer_v3(
+    # model = transformer_models.VisionTransformer_v3(
+    model = transformer_models.CoVisionTransformer(
         args=args,
-        img_dim=args.enc_layers,  # VisionTransformer_v3
+        img_dim=args.enc_layers,
         patch_dim=args.patch_dim,
         out_dim=args.numclass,
         embedding_dim=args.embedding_dim,
@@ -78,18 +79,12 @@
     try:
         from ptflops import get_model_complexity_info
 
-        def input_constructor(*largs, **lkwargs):
-            return {
-                "sequence_input_rgb": torch.ones(()).new_empty(
-                    (1, args.enc_layers, args.dim_feature // 3 * 2)
-                ),
-                "sequence_input_flow": torch.ones(()).new_empty(
-                    (1, args.enc_layers, args.dim_feature // 3)
-                ),
-            }
+        # Warm up model
+        model.forward_steps(torch.randn(1, args.dim_feature, args.enc_layers))
+        model.call_mode = "forward_step"
 
         flops, params = get_model_complexity_info(
-            model, (0, 0), input_constructor=input_constructor, as_strings=False
+            model, (args.dim_feature,), as_strings=False
         )
         print(f"Model FLOPs: {flops}")
         print(f"Model params: {params}")
@@ -98,11 +93,12 @@ def input_constructor(*largs, **lkwargs):
         print(e)
         ...
 
+    model.call_mode = "forward"
     model.to(device)
     loss_need = [
         "labels_encoder",
-        "labels_decoder",
+        # "labels_decoder",
     ]
     criterion = utl.SetCriterion(
         num_classes=args.numclass, losses=loss_need, args=args
     )
@@ -200,6 +196,7 @@ def input_constructor(*largs, **lkwargs):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             sampler_train.set_epoch(epoch)
+
         train_stats = train_one_epoch(
             model,
             criterion,
diff --git a/transformer_models/Transformer.py b/transformer_models/Transformer.py
index cf66e9a..e06f792 100644
--- a/transformer_models/Transformer.py
+++ b/transformer_models/Transformer.py
@@ -1,5 +1,6 @@
 from torch import nn
 from .Attention import SelfAttention
+from continual_transformers import CoSiTransformerEncoder, CoReSiTransformerEncoder
 
 
 class Residual(nn.Module):
@@ -80,6 +81,46 @@ def forward(self, x):
         return self.net(x)
 
 
+def CoTransformerModel(
+    dim,
+    depth,
+    heads,
+    mlp_dim,
+    dropout_rate=0.1,
+    attn_dropout_rate=0.1,
+    sequence_len=64,
+):
+    assert depth in {1, 2}
+
+    if depth == 1:
+        return CoSiTransformerEncoder(
+            sequence_len=sequence_len,
+            embed_dim=dim,
+            num_heads=heads,
+            dropout=dropout_rate,
+            in_proj_bias=False,
+            query_index=-1,
+            ff_hidden_dim=mlp_dim,
+            ff_activation=nn.GELU(),
+            device=None,
+            dtype=None,
+        )
+
+    # depth == 2
+    return CoReSiTransformerEncoder(
+        sequence_len=sequence_len,
+        embed_dim=dim,
+        num_heads=heads,
+        dropout=dropout_rate,
+        in_proj_bias=False,
+        query_index=-1,
+        ff_hidden_dim=mlp_dim,
+        ff_activation=nn.GELU(),
+        device=None,
+        dtype=None,
+    )
+
+
 def _register_ptflops():
     try:
         from ptflops import flops_counter as fc
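
Note on the complexity-measurement change in main.py: the patch replaces ptflops'
two-stream input_constructor with the continual-inference call-mode protocol. The
model is first warmed up on a full dummy clip via forward_steps so its internal
state (cached attention keys/values) is populated, call_mode is then switched to
"forward_step" so __call__ processes a single token while ptflops profiles it,
and finally call_mode is restored to "forward" before training. The following is
a minimal sketch of that pattern, not repo code: it assumes the CoTransformerModel
factory from this patch is importable, that continual-transformers and ptflops
are installed, and all dimensions are illustrative stand-ins for args.dim_feature
and args.enc_layers.

    # Sketch only: mirrors the measurement pattern in the patch above.
    import torch
    from ptflops import get_model_complexity_info
    from transformer_models.Transformer import CoTransformerModel

    feat_dim, seq_len = 512, 64  # stand-ins for args.dim_feature / args.enc_layers
    model = CoTransformerModel(
        dim=feat_dim, depth=2, heads=8, mlp_dim=1024, sequence_len=seq_len
    )

    # Warm up: run a full dummy sequence so the internal state buffers are filled
    # before single-step profiling (as the patch does in main.py).
    model.forward_steps(torch.randn(1, feat_dim, seq_len))

    # Route __call__ through forward_step so ptflops measures the cost of ONE step.
    model.call_mode = "forward_step"
    flops, params = get_model_complexity_info(
        model, (feat_dim,), as_strings=False, print_per_layer_stat=False
    )
    print(f"FLOPs per step: {flops}, params: {params}")

    # Restore regular full-sequence inference before training, as main.py does.
    model.call_mode = "forward"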
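
Note on the new CoTransformerModel factory: it only supports depth in {1, 2},
returning a single continual single-output encoder block (CoSiTransformerEncoder)
for depth 1 and the two-block recycling variant (CoReSiTransformerEncoder) for
depth 2; attn_dropout_rate is accepted for interface compatibility with the
original TransformerModel but is not forwarded. Below is a hypothetical usage
sketch, assuming the continual encoders take (batch, embed_dim, sequence_len)
input in forward_steps and a single (batch, embed_dim) token in forward_step;
the hyperparameters are illustrative, not taken from the repo's configs.

    # Hypothetical usage of the factory added in this patch.
    import torch
    from transformer_models.Transformer import CoTransformerModel

    enc = CoTransformerModel(dim=512, depth=1, heads=8, mlp_dim=1024, sequence_len=64)

    clip = torch.randn(1, 512, 64)       # (batch, embed_dim, sequence_len)
    out = enc.forward_steps(clip)        # process a whole clip, filling the state
    step_out = enc.forward_step(torch.randn(1, 512))  # then one new token at a time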