From 4e6f03e0c8d9380e1cb85cbd02e6d5c474c07abf Mon Sep 17 00:00:00 2001
From: LukasHedegaard
Date: Wed, 8 Dec 2021 13:33:41 +0000
Subject: [PATCH] Remove decoder from model

---
 transformer_models/ViT.py | 84 +++++++++++++++++++--------------------
 1 file changed, 42 insertions(+), 42 deletions(-)

diff --git a/transformer_models/ViT.py b/transformer_models/ViT.py
index dace6d1..4dec168 100644
--- a/transformer_models/ViT.py
+++ b/transformer_models/ViT.py
@@ -120,48 +120,48 @@ def __init__(
         self.to_cls_token = nn.Identity()
 
         # Decoder
-        factor = 1  # 5
-        dropout = args.decoder_attn_dropout_rate
-        # d_model = args.decoder_embedding_dim
-        n_heads = args.decoder_num_heads
-        d_layers = args.decoder_layers
-        d_ff = (
-            args.decoder_embedding_dim_out
-        )  # args.decoder_embedding_dim_out or 4*args.decoder_embedding_dim None
-        activation = "gelu"  # 'gelu'
-        self.decoder = Decoder(
-            [
-                DecoderLayer(
-                    AttentionLayer(
-                        FullAttention(True, factor, attention_dropout=dropout),  # True
-                        d_model,
-                        n_heads,
-                    ),  # ProbAttention FullAttention
-                    AttentionLayer(
-                        FullAttention(
-                            False, factor, attention_dropout=dropout
-                        ),  # False
-                        d_model,
-                        n_heads,
-                    ),
-                    d_model,
-                    d_ff,
-                    dropout=dropout,
-                    activation=activation,
-                )
-                for l in range(d_layers)
-            ],
-            norm_layer=torch.nn.LayerNorm(d_model),
-        )
-        self.decoder_cls_token = nn.Parameter(torch.zeros(1, args.query_num, d_model))
-        if positional_encoding_type == "learned":
-            self.decoder_position_encoding = LearnedPositionalEncoding(
-                args.query_num, self.embedding_dim, args.query_num
-            )
-        elif positional_encoding_type == "fixed":
-            self.decoder_position_encoding = FixedPositionalEncoding(
-                self.embedding_dim,
-            )
+        # factor = 1  # 5
+        # dropout = args.decoder_attn_dropout_rate
+        # # d_model = args.decoder_embedding_dim
+        # n_heads = args.decoder_num_heads
+        # d_layers = args.decoder_layers
+        # d_ff = (
+        #     args.decoder_embedding_dim_out
+        # )  # args.decoder_embedding_dim_out or 4*args.decoder_embedding_dim None
+        # activation = "gelu"  # 'gelu'
+        # self.decoder = Decoder(
+        #     [
+        #         DecoderLayer(
+        #             AttentionLayer(
+        #                 FullAttention(True, factor, attention_dropout=dropout),  # True
+        #                 d_model,
+        #                 n_heads,
+        #             ),  # ProbAttention FullAttention
+        #             AttentionLayer(
+        #                 FullAttention(
+        #                     False, factor, attention_dropout=dropout
+        #                 ),  # False
+        #                 d_model,
+        #                 n_heads,
+        #             ),
+        #             d_model,
+        #             d_ff,
+        #             dropout=dropout,
+        #             activation=activation,
+        #         )
+        #         for l in range(d_layers)
+        #     ],
+        #     norm_layer=torch.nn.LayerNorm(d_model),
+        # )
+        # self.decoder_cls_token = nn.Parameter(torch.zeros(1, args.query_num, d_model))
+        # if positional_encoding_type == "learned":
+        #     self.decoder_position_encoding = LearnedPositionalEncoding(
+        #         args.query_num, self.embedding_dim, args.query_num
+        #     )
+        # elif positional_encoding_type == "fixed":
+        #     self.decoder_position_encoding = FixedPositionalEncoding(
+        #         self.embedding_dim,
+        #     )
         print("position decoding :", positional_encoding_type)
         self.classifier = nn.Linear(d_model, out_dim)
         self.after_dropout = nn.Dropout(p=self.dropout_rate)