Enable Roberta AutoConversion, Align Paddle Roberta to HF Transformers (#4512)

* review

* tests

* changes

* changes

* Update test_modeling.py

* Update tests/transformers/roberta/test_modeling.py

Co-authored-by: 骑马小猫 <1435130236@qq.com>

---------

Co-authored-by: 骑马小猫 <1435130236@qq.com>
sijunhe and wj-Mcat authored Jan 29, 2023
1 parent 12d15f9 commit c9734ff
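
With this change, RoBERTa checkpoints published for HF Transformers can be
loaded into PaddleNLP with automatic weight conversion. Below is a minimal
sketch of the intended workflow (an illustration only: it assumes a PaddleNLP
build whose `from_pretrained` accepts the `from_hf_hub` flag, and
"roberta-base" is an example checkpoint name):

    import paddle
    from paddlenlp.transformers import RobertaModel

    # Download the PyTorch weights from the HF Hub and convert them using the
    # StateDictNameMapping entries introduced in this commit.
    model = RobertaModel.from_pretrained("roberta-base", from_hf_hub=True)
    model.eval()

    input_ids = paddle.to_tensor([[0, 31414, 232, 2]])  # "<s> Hello world </s>"
    sequence_output, pooled_output = model(input_ids)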
Showing 2 changed files with 293 additions and 25 deletions.
214 changes: 191 additions & 23 deletions paddlenlp/transformers/roberta/modeling.py
@@ -19,6 +19,7 @@
import paddle.nn as nn
import paddle.nn.functional as F

from ...utils.converter import StateDictNameMapping
from .. import PretrainedModel, register_base_model
from ..model_outputs import (
BaseModelOutputWithPoolingAndCrossAttentions,
@@ -43,6 +44,22 @@
]


def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.
Args:
x: paddle.Tensor x:
Returns: paddle.Tensor
"""
if past_key_values_length is None:
past_key_values_length = 0
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
mask = (input_ids != padding_idx).cast("int64")
incremental_indices = (paddle.cumsum(mask, axis=1) + past_key_values_length) * mask
return incremental_indices + padding_idx
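

# A quick illustration of the helper above (values assume padding_idx=1, as in
# roberta-base; illustrative only, not part of the diff):
#
#   input_ids = paddle.to_tensor([[0, 31414, 232, 2, 1, 1]])
#   create_position_ids_from_input_ids(input_ids, padding_idx=1, past_key_values_length=0)
#   # -> [[2, 3, 4, 5, 1, 1]]: non-pad tokens count up from padding_idx + 1,
#   #    pad positions stay at padding_idx.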


class RobertaEmbeddings(nn.Layer):
r"""
Include embeddings from word, position and token_type embeddings.
@@ -60,15 +77,9 @@ def __init__(self, config: RobertaConfig):

def forward(self, input_ids, token_type_ids=None, position_ids=None, past_key_values_length=None):
if position_ids is None:
-            # may need to use the shape op to unify static and dynamic graphs
-            ones = paddle.ones_like(input_ids, dtype="int64")
-            seq_length = paddle.cumsum(ones, axis=-1)
-            if self.cls_token_id == 0 or input_ids[0][0] == 0:  # position_ids for RobertaBPETokenizer
-                position_ids = seq_length + self.padding_idx + 1 - ones
-            else:  # position_ids for RobertaTokenizer
-                position_ids = seq_length - ones
-            if past_key_values_length is not None:
-                position_ids += past_key_values_length
+            position_ids = create_position_ids_from_input_ids(
+                input_ids, padding_idx=self.padding_idx, past_key_values_length=past_key_values_length
+            )
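            # The shared helper follows fairseq/HF RoBERTa numbering: real tokens
            # start at padding_idx + 1 and padding positions keep padding_idx.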
position_ids.stop_gradient = True
if token_type_ids is None:
token_type_ids = paddle.zeros_like(input_ids, dtype="int64")
@@ -123,6 +134,135 @@ class RobertaPretrainedModel(PretrainedModel):
}
base_model_prefix = "roberta"

@classmethod
def _get_name_mappings(cls, config: RobertaConfig) -> list[StateDictNameMapping]:
mappings = [
["embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight"],
["embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"],
["embeddings.token_type_embeddings.weight", "embeddings.token_type_embeddings.weight"],
["embeddings.LayerNorm.weight", "embeddings.layer_norm.weight"],
["embeddings.LayerNorm.bias", "embeddings.layer_norm.bias"],
]
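
        # Each entry is [torch_key, paddle_key, optional_action]; the converter
        # renames keys accordingly and applies "transpose" where marked, since
        # torch.nn.Linear stores weights as [out_features, in_features] while
        # paddle.nn.Linear stores [in_features, out_features]. Embedding and
        # LayerNorm weights are copied as-is.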

for layer_index in range(config.num_hidden_layers):
layer_mappings = [
[
f"encoder.layer.{layer_index}.attention.self.query.weight",
f"encoder.layers.{layer_index}.self_attn.q_proj.weight",
"transpose",
],
[
f"encoder.layer.{layer_index}.attention.self.query.bias",
f"encoder.layers.{layer_index}.self_attn.q_proj.bias",
],
[
f"encoder.layer.{layer_index}.attention.self.key.weight",
f"encoder.layers.{layer_index}.self_attn.k_proj.weight",
"transpose",
],
[
f"encoder.layer.{layer_index}.attention.self.key.bias",
f"encoder.layers.{layer_index}.self_attn.k_proj.bias",
],
[
f"encoder.layer.{layer_index}.attention.self.value.weight",
f"encoder.layers.{layer_index}.self_attn.v_proj.weight",
"transpose",
],
[
f"encoder.layer.{layer_index}.attention.self.value.bias",
f"encoder.layers.{layer_index}.self_attn.v_proj.bias",
],
[
f"encoder.layer.{layer_index}.attention.output.dense.weight",
f"encoder.layers.{layer_index}.self_attn.out_proj.weight",
"transpose",
],
[
f"encoder.layer.{layer_index}.attention.output.dense.bias",
f"encoder.layers.{layer_index}.self_attn.out_proj.bias",
],
[
f"encoder.layer.{layer_index}.attention.output.LayerNorm.weight",
f"encoder.layers.{layer_index}.norm1.weight",
],
[
f"encoder.layer.{layer_index}.attention.output.LayerNorm.bias",
f"encoder.layers.{layer_index}.norm1.bias",
],
[
f"encoder.layer.{layer_index}.intermediate.dense.weight",
f"encoder.layers.{layer_index}.linear1.weight",
"transpose",
],
[f"encoder.layer.{layer_index}.intermediate.dense.bias", f"encoder.layers.{layer_index}.linear1.bias"],
[
f"encoder.layer.{layer_index}.output.dense.weight",
f"encoder.layers.{layer_index}.linear2.weight",
"transpose",
],
[f"encoder.layer.{layer_index}.output.dense.bias", f"encoder.layers.{layer_index}.linear2.bias"],
[f"encoder.layer.{layer_index}.output.LayerNorm.weight", f"encoder.layers.{layer_index}.norm2.weight"],
[f"encoder.layer.{layer_index}.output.LayerNorm.bias", f"encoder.layers.{layer_index}.norm2.bias"],
]
mappings.extend(layer_mappings)

        # architectures other than the bare RobertaModel prepend the "roberta." model prefix
if config.architectures is not None and "RobertaModel" not in config.architectures:
for mapping in mappings:
mapping[0] = "roberta." + mapping[0]

if cls.__name__ != "RobertaModel":
for mapping in mappings:
mapping[1] = "roberta." + mapping[1]

mappings.extend(
[
["pooler.dense.weight", "roberta.pooler.dense.weight", "transpose"],
["pooler.dense.bias", "roberta.pooler.dense.bias"],
]
)

if config.architectures is not None:
if "RobertaForSequenceClassification" in config.architectures:
mappings.extend(
[
["classifier.out_proj.weight", "classifier.out_proj.weight", "transpose"],
["classifier.out_proj.bias", "classifier.out_proj.bias"],
["classifier.dense.weight", "classifier.dense.weight", "transpose"],
["classifier.dense.bias", "classifier.dense.bias"],
]
)
if "RobertaForMaskedLM" in config.architectures:
mappings.extend(
[
["lm_head.bias", "lm_head.bias"],
["lm_head.dense.weight", "lm_head.dense.weight"],
["lm_head.dense.bias", "lm_head.dense.bias"],
["lm_head.layer_norm.weight", "lm_head.layer_norm.weight"],
["lm_head.layer_norm.bias", "lm_head.layer_norm.bias"],
]
)
if (
"RobertaForTokenClassification" in config.architectures
or "RobertaForMultipleChoice" in config.architectures
):
mappings.extend(
[
["classifier.weight", "classifier.weight", "transpose"],
["classifier.bias", "classifier.bias"],
]
)
if "RobertaForQuestionAnswering" in config.architectures:
mappings.extend(
[
["qa_outputs.weight", "classifier.weight", "transpose"],
["qa_outputs.bias", "classifier.bias"],
]
)

return [StateDictNameMapping(*mapping) for mapping in mappings]

def init_weights(self, layer):
"""Initialization hook"""
if isinstance(layer, (nn.Linear, nn.Embedding)):
@@ -202,7 +342,7 @@ class RobertaModel(RobertaPretrainedModel):
Defaults to `101`.
"""

-    def __init__(self, config: RobertaConfig):
+    def __init__(self, config: RobertaConfig, add_pooling_layer=True):
super(RobertaModel, self).__init__(config)

self.pad_token_id = config.pad_token_id
@@ -219,7 +359,7 @@ def __init__(self, config: RobertaConfig):
act_dropout=0,
)
self.encoder = nn.TransformerEncoder(encoder_layer, config.num_hidden_layers)
-        self.pooler = RobertaPooler(config.hidden_size)
+        self.pooler = RobertaPooler(config.hidden_size) if add_pooling_layer else None
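        # add_pooling_layer mirrors HF Transformers' RobertaModel: task-specific
        # heads that never read the pooled output can skip creating (and
        # converting) the pooler weights.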
self.apply(self.init_weights)

def get_input_embeddings(self):
@@ -316,7 +456,7 @@ def forward(
past_key_values_length = past_key_values[0][0].shape[2]
if attention_mask is None:
attention_mask = paddle.unsqueeze(
-                (input_ids == self.pad_token_id).astype(self.pooler.dense.weight.dtype) * -1e4, axis=[1, 2]
+                (input_ids == self.pad_token_id).astype(paddle.get_default_dtype()) * -1e4, axis=[1, 2]
)
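            # paddle.get_default_dtype() is used instead of the pooler weight's
            # dtype because self.pooler may now be None; padding positions get a
            # large negative additive bias (-1e4) so attention softmax ignores them.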
if past_key_values is not None:
batch_size = past_key_values[0][0].shape[0]
@@ -345,11 +485,11 @@
)
if isinstance(encoder_outputs, type(embedding_output)):
sequence_output = encoder_outputs
-            pooled_output = self.pooler(sequence_output)
+            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
return (sequence_output, pooled_output)
else:
sequence_output = encoder_outputs[0]
-            pooled_output = self.pooler(sequence_output)
+            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]

@@ -375,7 +515,7 @@ class RobertaForQuestionAnswering(RobertaPretrainedModel):
def __init__(self, config: RobertaConfig):
super(RobertaForQuestionAnswering, self).__init__(config)

-        self.roberta = RobertaModel(config)
+        self.roberta = RobertaModel(config, add_pooling_layer=False)
self.classifier = nn.Linear(config.hidden_size, 2)
self.apply(self.init_weights)

@@ -483,6 +623,28 @@ def forward(
)


class RobertaClassificationHead(nn.Layer):
"""Head for sentence-level classification tasks."""

def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
classifier_dropout = (
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.dropout = nn.Dropout(classifier_dropout)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

def forward(self, features, **kwargs):
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
x = self.dropout(x)
x = self.dense(x)
x = paddle.tanh(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
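
# RobertaClassificationHead above is equivalent to HF Transformers' head of the
# same name: sentence-level logits come from the <s> (first) token's hidden
# state rather than from RobertaPooler's output.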


class RobertaForSequenceClassification(RobertaPretrainedModel):
r"""
Roberta Model with a linear layer on top of the output layer,
@@ -501,12 +663,12 @@ class RobertaForSequenceClassification(RobertaPretrainedModel):

def __init__(self, config: RobertaConfig):
super(RobertaForSequenceClassification, self).__init__(config)
-        self.roberta = RobertaModel(config)
+        self.roberta = RobertaModel(config, add_pooling_layer=False)

self.dropout = nn.Dropout(
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
-        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.classifier = RobertaClassificationHead(config)
self.apply(self.init_weights)

def forward(
@@ -573,10 +735,10 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
-        pooled_output = outputs[1]
+        sequence_output = outputs[0]

-        pooled_output = self.dropout(pooled_output)
-        logits = self.classifier(pooled_output)
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
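        # Matches HF: logits are computed from token-level hidden states through
        # RobertaClassificationHead; the pooled output is no longer used here.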

loss = None
if labels is not None:
@@ -620,7 +782,7 @@ class RobertaForTokenClassification(RobertaPretrainedModel):
def __init__(self, config: RobertaConfig):
super(RobertaForTokenClassification, self).__init__(config)

-        self.roberta = RobertaModel(config)
+        self.roberta = RobertaModel(config, add_pooling_layer=False)
self.dropout = nn.Dropout(
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
@@ -836,6 +998,7 @@ def forward(
"""

num_choices = input_ids.shape[1]

input_ids = input_ids.reshape((-1, input_ids.shape[-1])) if input_ids is not None else None
position_ids = position_ids.reshape((-1, position_ids.shape[-1])) if position_ids is not None else None
@@ -886,7 +1049,7 @@ class RobertaForMaskedLM(RobertaPretrainedModel):
def __init__(self, config: RobertaConfig):
super(RobertaForMaskedLM, self).__init__(config)

-        self.roberta = RobertaModel(config)
+        self.roberta = RobertaModel(config, add_pooling_layer=False)
self.lm_head = RobertaLMHead(config)

self.apply(self.init_weights)
@@ -998,7 +1161,12 @@ def __init__(self, config: RobertaConfig):
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)

tensor = paddle.zeros((config.vocab_size,))
self.bias = paddle.create_parameter(
shape=tensor.shape, dtype=tensor.dtype, default_initializer=nn.initializer.Assign(tensor)
)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
self.decoder.bias = self.bias
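        # Exposing the bias as a standalone parameter and sharing it with the
        # decoder mirrors HF's RobertaLMHead, so the "lm_head.bias" entry in
        # _get_name_mappings maps onto a real Paddle parameter.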

def forward(self, features, **kwargs):
x = self.dense(features)
@@ -1023,7 +1191,7 @@ class RobertaForCausalLM(RobertaPretrainedModel):

def __init__(self, config: RobertaConfig):
super().__init__(config)
-        self.roberta = RobertaModel(config)
+        self.roberta = RobertaModel(config, add_pooling_layer=False)
self.lm_head = RobertaLMHead(config)
self.apply(self.init_weights)

104 changes: 102 additions & 2 deletions tests/transformers/roberta/test_modeling.py (diff not shown)
