diff --git a/src/transformers/configuration_fsmt.py b/src/transformers/configuration_fsmt.py
index f47244daa15263..3490ca0e70ad00 100644
--- a/src/transformers/configuration_fsmt.py
+++ b/src/transformers/configuration_fsmt.py
@@ -68,10 +68,6 @@
             Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
         init_std (:obj:`float`, optional, defaults to 0.02):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
-        normalize_before (:obj:`bool`, optional, defaults to :obj:`False`):
-            Call layernorm before attention ops.
-        add_final_layer_norm (:obj:`bool`, optional, defaults to :obj:`False`):
-            Why not add another layernorm?
         scale_embedding (:obj:`bool`, optional, defaults to :obj:`True`):
             Scale embeddings by diving by sqrt(d_model).
         bos_token_id (:obj:`int`, optional, defaults to 0)
@@ -140,9 +136,7 @@ def __init__(
         bos_token_id=0,
         eos_token_id=2,
         decoder_start_token_id=2,
-        add_final_layer_norm=False,
         is_encoder_decoder=True,
-        normalize_before=False,
         scale_embedding=True,
         tie_word_embeddings=False,
         **common_kwargs
@@ -191,8 +185,6 @@ def __init__(
 
         # Params introduced for Mbart
         self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
-        self.normalize_before = normalize_before  # combo of fairseq's encoder_ and decoder_normalize_before
-        self.add_final_layer_norm = add_final_layer_norm
 
         # 3 Types of Dropout
         self.attention_dropout = attention_dropout
diff --git a/src/transformers/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
index d88e4496f0b564..62ac3514f755a5 100755
--- a/src/transformers/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
@@ -319,9 +319,7 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
         "bos_token_id": 0,
         "pad_token_id": 1,
         "eos_token_id": 2,
-        "add_final_layer_norm": False,
         "is_encoder_decoder": True,
-        "normalize_before": False,
         "scale_embedding": True,
         "tie_word_embeddings": False,
     }
diff --git a/src/transformers/modeling_fsmt.py b/src/transformers/modeling_fsmt.py
index 5706c89023fd48..ede39e76512a98 100644
--- a/src/transformers/modeling_fsmt.py
+++ b/src/transformers/modeling_fsmt.py
@@ -337,7 +337,6 @@ def __init__(self, config: FSMTConfig):
         super().__init__()
         self.embed_dim = config.d_model
         self.self_attn = Attention(self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout)
-        self.normalize_before = config.normalize_before
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.dropout = config.dropout
         self.activation_fn = ACT2FN[config.activation_function]
@@ -359,26 +358,20 @@ def forward(self, x, encoder_padding_mask, output_attentions=False):
             encoded output of shape `(seq_len, batch, embed_dim)`
         """
         residual = x
-        if self.normalize_before:
-            x = self.self_attn_layer_norm(x)
         x, attn_weights = self.self_attn(
             query=x, key=x, key_padding_mask=encoder_padding_mask, output_attentions=output_attentions
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        if not self.normalize_before:
-            x = self.self_attn_layer_norm(x)
+        x = self.self_attn_layer_norm(x)
 
         residual = x
-        if self.normalize_before:
-            x = self.final_layer_norm(x)
         x = self.activation_fn(self.fc1(x))
         x = F.dropout(x, p=self.activation_dropout, training=self.training)
         x = self.fc2(x)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        if not self.normalize_before:
-            x = self.final_layer_norm(x)
+        x = self.final_layer_norm(x)
         return x, attn_weights
 
 
@@ -411,8 +404,6 @@ def __init__(self, config: FSMTConfig, embed_tokens):
             init_size=num_embeddings + self.padding_idx + 1,  # removed: config.max_position_embeddings
         )
         self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
-        # mbart has one extra layer_norm
-        self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
 
     def forward(
         self, input_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=False
@@ -459,8 +450,6 @@ def forward(
             if output_attentions:
                 all_attentions = all_attentions + (attn,)
 
-        if self.layer_norm:
-            x = self.layer_norm(x)
         if output_hidden_states:
             encoder_states.append(x)
             # T x B x C -> B x T x C
@@ -487,7 +476,6 @@ def __init__(self, config: FSMTConfig):
         self.dropout = config.dropout
         self.activation_fn = ACT2FN[config.activation_function]
         self.activation_dropout = config.activation_dropout
-        self.normalize_before = config.normalize_before
 
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.encoder_attn = Attention(
@@ -515,10 +503,8 @@ def forward(
 
         if layer_state is None:
             layer_state = {}
-        if self.normalize_before:
-            x = self.self_attn_layer_norm(x)
-        # Self Attention
 
+        # Self Attention
         x, self_attn_weights = self.self_attn(
             query=x,
             key=x,
@@ -529,14 +515,11 @@ def forward(
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        if not self.normalize_before:
-            x = self.self_attn_layer_norm(x)
+        x = self.self_attn_layer_norm(x)
 
         # Cross attention
         residual = x
         assert self.encoder_attn.cache_key != self.self_attn.cache_key
-        if self.normalize_before:
-            x = self.encoder_attn_layer_norm(x)
         x, _ = self.encoder_attn(
             query=x,
             key=encoder_hidden_states,
@@ -545,20 +528,16 @@ def forward(
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        if not self.normalize_before:
-            x = self.encoder_attn_layer_norm(x)
+        x = self.encoder_attn_layer_norm(x)
 
         # Fully Connected
         residual = x
-        if self.normalize_before:
-            x = self.final_layer_norm(x)
         x = self.activation_fn(self.fc1(x))
         x = F.dropout(x, p=self.activation_dropout, training=self.training)
         x = self.fc2(x)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        if not self.normalize_before:
-            x = self.final_layer_norm(x)
+        x = self.final_layer_norm(x)
         return (
             x,
             self_attn_weights,
@@ -593,7 +572,6 @@ def __init__(self, config: FSMTConfig, embed_tokens: nn.Embedding):
         self.layers = nn.ModuleList(
             [DecoderLayer(config) for _ in range(config.decoder_layers)]
         )  # type: List[DecoderLayer]
-        self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
 
         self.output_projection = nn.Linear(
             self.embed_tokens.weight.shape[1],
@@ -695,8 +673,6 @@ def forward(
 
             if use_cache:
                 next_decoder_cache.append(layer_past.copy())
 
-            if self.layer_norm and (idx == len(self.layers) - 1):  # last layer of mbart
-                x = self.layer_norm(x)
             if output_attentions:
                 all_self_attns += (layer_self_attn,)
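
For reference, the layer structure left behind by these removals is the plain post-norm ("LayerNorm after the residual add") Transformer block. The snippet below is a minimal, hypothetical sketch of that pattern only, not the actual FSMT code: it uses torch.nn.MultiheadAttention and a ReLU feed-forward as stand-ins for FSMT's Attention module and configured activation, and keeps the same ordering of attention/FFN, dropout, residual, and LayerNorm as the "+" lines above.

    # Sketch of the post-norm residual pattern (illustrative, not FSMT itself).
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class PostNormBlock(nn.Module):
        def __init__(self, embed_dim, num_heads, ffn_dim, dropout=0.1):
            super().__init__()
            # stand-in for FSMT's Attention module
            self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
            self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
            self.fc1 = nn.Linear(embed_dim, ffn_dim)
            self.fc2 = nn.Linear(ffn_dim, embed_dim)
            self.final_layer_norm = nn.LayerNorm(embed_dim)
            self.dropout = dropout

        def forward(self, x, key_padding_mask=None):
            # x: (seq_len, batch, embed_dim), matching the T x B x C layout used in the diff
            # attention sub-block: attend -> dropout -> residual -> LayerNorm
            residual = x
            x, _ = self.self_attn(x, x, x, key_padding_mask=key_padding_mask)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = self.self_attn_layer_norm(residual + x)

            # feed-forward sub-block: FFN -> dropout -> residual -> LayerNorm
            residual = x
            x = F.dropout(F.relu(self.fc1(x)), p=self.dropout, training=self.training)
            x = F.dropout(self.fc2(x), p=self.dropout, training=self.training)
            x = self.final_layer_norm(residual + x)
            return x

Because the layer norm always runs after the residual addition in this layout, the pre-norm branches (`normalize_before`) and the extra final/mBART layer norm (`add_final_layer_norm`) have no code path left to serve, which is why both config flags and their corresponding modules are dropped throughout the diff.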