
Commit

remove unused code #2
stas00 committed Sep 8, 2020
1 parent 416fccf commit 05f09bb
Showing 3 changed files with 6 additions and 40 deletions.
8 changes: 0 additions & 8 deletions src/transformers/configuration_fsmt.py
@@ -68,10 +68,6 @@
             Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
         init_std (:obj:`float`, optional, defaults to 0.02):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
-        normalize_before (:obj:`bool`, optional, defaults to :obj:`False`):
-            Call layernorm before attention ops.
-        add_final_layer_norm (:obj:`bool`, optional, defaults to :obj:`False`):
-            Why not add another layernorm?
         scale_embedding (:obj:`bool`, optional, defaults to :obj:`True`):
             Scale embeddings by dividing by sqrt(d_model).
         bos_token_id (:obj:`int`, optional, defaults to 0)
@@ -140,9 +136,7 @@ def __init__(
         bos_token_id=0,
         eos_token_id=2,
         decoder_start_token_id=2,
-        add_final_layer_norm=False,
         is_encoder_decoder=True,
-        normalize_before=False,
         scale_embedding=True,
         tie_word_embeddings=False,
         **common_kwargs
@@ -191,8 +185,6 @@ def __init__(

         # Params introduced for Mbart
         self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
-        self.normalize_before = normalize_before  # combo of fairseq's encoder_ and decoder_normalize_before
-        self.add_final_layer_norm = add_final_layer_norm
 
         # 3 Types of Dropout
         self.attention_dropout = attention_dropout
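Aside: the surviving `scale_embedding` flag behaves as the docstring and inline comment above describe: when True, token embeddings are multiplied by sqrt(d_model) before entering the network. A minimal sketch of that scaling, with variable names and sizes chosen here purely for illustration (not FSMT defaults):

import math

import torch
import torch.nn as nn

d_model, vocab_size, scale_embedding = 1024, 31000, True  # illustrative sizes only
embed_tokens = nn.Embedding(vocab_size, d_model)

# "scale factor will be sqrt(d_model) if True", per the comment in __init__ above
embed_scale = math.sqrt(d_model) if scale_embedding else 1.0

input_ids = torch.randint(0, vocab_size, (2, 7))    # (batch, seq_len)
embeddings = embed_tokens(input_ids) * embed_scale  # scaled by 32.0 here
print(embeddings.shape)                             # torch.Size([2, 7, 1024])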
2 changes: 0 additions & 2 deletions src/transformers/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
@@ -319,9 +319,7 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
"bos_token_id": 0,
"pad_token_id": 1,
"eos_token_id": 2,
"add_final_layer_norm": False,
"is_encoder_decoder": True,
"normalize_before": False,
"scale_embedding": True,
"tie_word_embeddings": False,
}
36 changes: 6 additions & 30 deletions src/transformers/modeling_fsmt.py
@@ -337,7 +337,6 @@ def __init__(self, config: FSMTConfig):
         super().__init__()
         self.embed_dim = config.d_model
         self.self_attn = Attention(self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout)
-        self.normalize_before = config.normalize_before
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.dropout = config.dropout
         self.activation_fn = ACT2FN[config.activation_function]
@@ -359,26 +358,20 @@ def forward(self, x, encoder_padding_mask, output_attentions=False):
             encoded output of shape `(seq_len, batch, embed_dim)`
         """
         residual = x
-        if self.normalize_before:
-            x = self.self_attn_layer_norm(x)
         x, attn_weights = self.self_attn(
             query=x, key=x, key_padding_mask=encoder_padding_mask, output_attentions=output_attentions
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        if not self.normalize_before:
-            x = self.self_attn_layer_norm(x)
+        x = self.self_attn_layer_norm(x)
 
         residual = x
-        if self.normalize_before:
-            x = self.final_layer_norm(x)
         x = self.activation_fn(self.fc1(x))
         x = F.dropout(x, p=self.activation_dropout, training=self.training)
         x = self.fc2(x)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        if not self.normalize_before:
-            x = self.final_layer_norm(x)
+        x = self.final_layer_norm(x)
         return x, attn_weights
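The deleted branches above were the only users of `normalize_before`: with the flag gone, the encoder layer always applies LayerNorm after the residual addition (post-norm, as in the original Transformer) rather than before the attention and feed-forward ops (pre-norm). A rough sketch of the two orderings the flag used to switch between, written for illustration and not taken from the repository:

import torch
import torch.nn as nn

def post_norm_block(x, sublayer, layer_norm):
    # the ordering FSMT now always uses: sublayer, residual add, then LayerNorm
    return layer_norm(x + sublayer(x))

def pre_norm_block(x, sublayer, layer_norm):
    # the removed `normalize_before=True` path: LayerNorm first, then sublayer and residual add
    return x + sublayer(layer_norm(x))

d_model = 16
layer_norm = nn.LayerNorm(d_model)
ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.ReLU(), nn.Linear(4 * d_model, d_model))
x = torch.randn(3, 5, d_model)  # (seq_len, batch, d_model), matching the T x B x C layout used in this file
print(post_norm_block(x, ffn, layer_norm).shape, pre_norm_block(x, ffn, layer_norm).shape)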


@@ -411,8 +404,6 @@ def __init__(self, config: FSMTConfig, embed_tokens):
             init_size=num_embeddings + self.padding_idx + 1,  # removed: config.max_position_embeddings
         )
         self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
-        # mbart has one extra layer_norm
-        self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
 
     def forward(
         self, input_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=False
@@ -459,8 +450,6 @@ def forward(
             if output_attentions:
                 all_attentions = all_attentions + (attn,)
 
-        if self.layer_norm:
-            x = self.layer_norm(x)
         if output_hidden_states:
             encoder_states.append(x)
             # T x B x C -> B x T x C
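The two lines dropped here were the stack-level counterpart of the same cleanup: an extra LayerNorm applied once after all encoder layers, built only when `normalize_before` was set (the `# mbart has one extra layer_norm` pattern removed from `__init__` above). A toy sketch of that optional-final-norm pattern, not the module's actual code:

import torch
import torch.nn as nn

class TinyStack(nn.Module):
    """Illustrative only: a stack with the optional trailing LayerNorm that FSMT just removed."""

    def __init__(self, d_model=16, n_layers=2, normalize_before=False):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_layers)])
        # the removed pattern: build an extra stack-level LayerNorm only for the pre-norm variant
        self.layer_norm = nn.LayerNorm(d_model) if normalize_before else None

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        if self.layer_norm is not None:  # after this commit, FSMT never takes this branch
            x = self.layer_norm(x)
        return x

print(TinyStack()(torch.randn(4, 16)).shape)  # torch.Size([4, 16])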
@@ -487,7 +476,6 @@ def __init__(self, config: FSMTConfig):
         self.dropout = config.dropout
         self.activation_fn = ACT2FN[config.activation_function]
         self.activation_dropout = config.activation_dropout
-        self.normalize_before = config.normalize_before
 
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.encoder_attn = Attention(
@@ -515,10 +503,8 @@ def forward(

         if layer_state is None:
             layer_state = {}
-        if self.normalize_before:
-            x = self.self_attn_layer_norm(x)
-        # Self Attention
 
+        # Self Attention
         x, self_attn_weights = self.self_attn(
             query=x,
             key=x,
@@ -529,14 +515,11 @@ def forward(
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        if not self.normalize_before:
-            x = self.self_attn_layer_norm(x)
+        x = self.self_attn_layer_norm(x)
 
         # Cross attention
         residual = x
         assert self.encoder_attn.cache_key != self.self_attn.cache_key
-        if self.normalize_before:
-            x = self.encoder_attn_layer_norm(x)
         x, _ = self.encoder_attn(
             query=x,
             key=encoder_hidden_states,
@@ -545,20 +528,16 @@
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        if not self.normalize_before:
-            x = self.encoder_attn_layer_norm(x)
+        x = self.encoder_attn_layer_norm(x)
 
         # Fully Connected
         residual = x
-        if self.normalize_before:
-            x = self.final_layer_norm(x)
         x = self.activation_fn(self.fc1(x))
         x = F.dropout(x, p=self.activation_dropout, training=self.training)
         x = self.fc2(x)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        if not self.normalize_before:
-            x = self.final_layer_norm(x)
+        x = self.final_layer_norm(x)
         return (
             x,
             self_attn_weights,
@@ -593,7 +572,6 @@ def __init__(self, config: FSMTConfig, embed_tokens: nn.Embedding):
         self.layers = nn.ModuleList(
             [DecoderLayer(config) for _ in range(config.decoder_layers)]
         )  # type: List[DecoderLayer]
-        self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
 
         self.output_projection = nn.Linear(
             self.embed_tokens.weight.shape[1],
@@ -695,8 +673,6 @@ def forward(
             if use_cache:
                 next_decoder_cache.append(layer_past.copy())
 
-            if self.layer_norm and (idx == len(self.layers) - 1):  # last layer of mbart
-                x = self.layer_norm(x)
             if output_attentions:
                 all_self_attns += (layer_self_attn,)
 
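With both flags gone, the decoder layer now unconditionally runs the same post-norm pattern through its three sub-blocks (self-attention, cross-attention, feed-forward), and the decoder stack no longer appends the mBART-style final LayerNorm. A condensed sketch of the resulting ordering; the flat function signature is a simplification for readability, not the module's real API:

import torch.nn.functional as F

def decoder_layer_forward(x, self_attn, encoder_attn, fc1, fc2, norms, activation_fn, p_drop, training):
    """Illustrative summary of the post-norm ordering; norms = (self_attn_ln, encoder_attn_ln, final_ln)."""
    self_attn_layer_norm, encoder_attn_layer_norm, final_layer_norm = norms

    residual = x  # 1) self-attention block
    x = self_attn(x)
    x = F.dropout(x, p=p_drop, training=training)
    x = self_attn_layer_norm(residual + x)

    residual = x  # 2) cross-attention block
    x = encoder_attn(x)
    x = F.dropout(x, p=p_drop, training=training)
    x = encoder_attn_layer_norm(residual + x)

    residual = x  # 3) feed-forward block
    x = fc2(activation_fn(fc1(x)))
    x = F.dropout(x, p=p_drop, training=training)
    x = final_layer_norm(residual + x)
    return x  # no extra stack-level LayerNorm follows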
