
Commit

Remove decoder from model
LukasHedegaard committed Dec 8, 2021
1 parent 9ae64df commit 4e6f03e
Showing 1 changed file with 42 additions and 42 deletions.
84 changes: 42 additions & 42 deletions transformer_models/ViT.py
@@ -120,48 +120,48 @@ def __init__(
         self.to_cls_token = nn.Identity()
 
         # Decoder
-        factor = 1  # 5
-        dropout = args.decoder_attn_dropout_rate
-        # d_model = args.decoder_embedding_dim
-        n_heads = args.decoder_num_heads
-        d_layers = args.decoder_layers
-        d_ff = (
-            args.decoder_embedding_dim_out
-        )  # args.decoder_embedding_dim_out or 4*args.decoder_embedding_dim None
-        activation = "gelu"  # 'gelu'
-        self.decoder = Decoder(
-            [
-                DecoderLayer(
-                    AttentionLayer(
-                        FullAttention(True, factor, attention_dropout=dropout),  # True
-                        d_model,
-                        n_heads,
-                    ),  # ProbAttention FullAttention
-                    AttentionLayer(
-                        FullAttention(
-                            False, factor, attention_dropout=dropout
-                        ),  # False
-                        d_model,
-                        n_heads,
-                    ),
-                    d_model,
-                    d_ff,
-                    dropout=dropout,
-                    activation=activation,
-                )
-                for l in range(d_layers)
-            ],
-            norm_layer=torch.nn.LayerNorm(d_model),
-        )
-        self.decoder_cls_token = nn.Parameter(torch.zeros(1, args.query_num, d_model))
-        if positional_encoding_type == "learned":
-            self.decoder_position_encoding = LearnedPositionalEncoding(
-                args.query_num, self.embedding_dim, args.query_num
-            )
-        elif positional_encoding_type == "fixed":
-            self.decoder_position_encoding = FixedPositionalEncoding(
-                self.embedding_dim,
-            )
+        # factor = 1  # 5
+        # dropout = args.decoder_attn_dropout_rate
+        # # d_model = args.decoder_embedding_dim
+        # n_heads = args.decoder_num_heads
+        # d_layers = args.decoder_layers
+        # d_ff = (
+        #     args.decoder_embedding_dim_out
+        # )  # args.decoder_embedding_dim_out or 4*args.decoder_embedding_dim None
+        # activation = "gelu"  # 'gelu'
+        # self.decoder = Decoder(
+        #     [
+        #         DecoderLayer(
+        #             AttentionLayer(
+        #                 FullAttention(True, factor, attention_dropout=dropout),  # True
+        #                 d_model,
+        #                 n_heads,
+        #             ),  # ProbAttention FullAttention
+        #             AttentionLayer(
+        #                 FullAttention(
+        #                     False, factor, attention_dropout=dropout
+        #                 ),  # False
+        #                 d_model,
+        #                 n_heads,
+        #             ),
+        #             d_model,
+        #             d_ff,
+        #             dropout=dropout,
+        #             activation=activation,
+        #         )
+        #         for l in range(d_layers)
+        #     ],
+        #     norm_layer=torch.nn.LayerNorm(d_model),
+        # )
+        # self.decoder_cls_token = nn.Parameter(torch.zeros(1, args.query_num, d_model))
+        # if positional_encoding_type == "learned":
+        #     self.decoder_position_encoding = LearnedPositionalEncoding(
+        #         args.query_num, self.embedding_dim, args.query_num
+        #     )
+        # elif positional_encoding_type == "fixed":
+        #     self.decoder_position_encoding = FixedPositionalEncoding(
+        #         self.embedding_dim,
+        #     )
         print("position decoding :", positional_encoding_type)
         self.classifier = nn.Linear(d_model, out_dim)
         self.after_dropout = nn.Dropout(p=self.dropout_rate)
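For readers skimming the change: after this commit only the encoder-side pieces remain in the constructor, since the Informer-style Decoder stack, the decoder cls token, and the decoder positional encoding are commented out. Below is a minimal sketch of what the surviving head looks like, assuming the pieces visible in the diff context (`to_cls_token`, `classifier`, `after_dropout`). The class name, constructor signature, and the forward pass are illustrative assumptions, not the exact code in transformer_models/ViT.py.

# Minimal sketch (assumed, not the repository's exact class): what remains
# of the head once the decoder stack is removed from __init__.
import torch
import torch.nn as nn

class EncoderOnlyHead(nn.Module):
    def __init__(self, d_model: int, out_dim: int, dropout_rate: float = 0.1):
        super().__init__()
        # Kept from the diff context: identity mapping for the cls token.
        self.to_cls_token = nn.Identity()
        # The Decoder(...), decoder_cls_token, and decoder positional
        # encodings are no longer constructed here.
        self.classifier = nn.Linear(d_model, out_dim)
        self.after_dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumed usage: x is the encoder output of shape (batch, tokens, d_model);
        # classify from the first (cls) token.
        cls = self.to_cls_token(x[:, 0])
        return self.classifier(self.after_dropout(cls))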
