Add support for XLM-R XL and XXL models by modeling_xlm_roberta_xl.py #13727

Merged: 27 commits merged on Jan 29, 2022

Changes from 1 commit (27 commits total)
70f0f04
add xlm roberta xl
Soonhwan-Kwon Sep 24, 2021
b831d60
add convert xlm xl fairseq checkpoint to pytorch
Soonhwan-Kwon Sep 24, 2021
123aab4
fix init and documents for xlm-roberta-xl
Soonhwan-Kwon Sep 24, 2021
1514f67
fix indention
Soonhwan-Kwon Sep 24, 2021
57d72ca
add test for XLM-R xl,xxl
Soonhwan-Kwon Oct 13, 2021
bd19941
fix model hub name
Soonhwan-Kwon Nov 14, 2021
d2d2715
fix some stuff
patrickvonplaten Dec 31, 2021
6be7307
Merge branch 'master' of https://github.com/huggingface/transformers …
patrickvonplaten Dec 31, 2021
9b4203f
up
patrickvonplaten Dec 31, 2021
5fca25a
correct init
patrickvonplaten Dec 31, 2021
df499c7
fix more
patrickvonplaten Dec 31, 2021
6a9c09f
fix as suggestions
Soonhwan-Kwon Jan 6, 2022
83852d1
add torch_device
Soonhwan-Kwon Jan 6, 2022
21bcebb
fix default values of doc strings
Soonhwan-Kwon Jan 6, 2022
7b058be
fix leftovers
patrickvonplaten Jan 28, 2022
351ada4
Merge branch 'master' of https://github.com/huggingface/transformers …
patrickvonplaten Jan 28, 2022
c4af533
merge to master
patrickvonplaten Jan 28, 2022
864620b
up
patrickvonplaten Jan 28, 2022
a9b13b8
correct hub names
patrickvonplaten Jan 28, 2022
1525c94
fix docs
patrickvonplaten Jan 28, 2022
4762f20
fix model
patrickvonplaten Jan 28, 2022
316a750
up
patrickvonplaten Jan 28, 2022
f9ad5ff
finalize
patrickvonplaten Jan 28, 2022
b9b80f4
last fix
patrickvonplaten Jan 28, 2022
1d49d20
Apply suggestions from code review
patrickvonplaten Jan 29, 2022
ad3e260
add copied from
patrickvonplaten Jan 29, 2022
9142af2
make style
patrickvonplaten Jan 29, 2022
add convert xlm xl fairseq checkpoint to pytorch
Soonhwan-Kwon committed Sep 24, 2021
commit b831d6042b654b24cd3f1576598fb664f747fc22
@@ -0,0 +1,183 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert RoBERTa checkpoint."""

import argparse
import pathlib

import fairseq
import torch
from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
from fairseq.modules import TransformerSentenceEncoderLayer
from packaging import version

from transformers import XLMRobertaXLConfig, XLMRobertaXLForMaskedLM, XLMRobertaXLForSequenceClassification
from transformers.models.bert.modeling_bert import (
    BertIntermediate,
    BertLayer,
    BertOutput,
    BertSelfAttention,
    BertSelfOutput,
)
from transformers.models.roberta.modeling_roberta import RobertaAttention
from transformers.utils import logging

if version.parse(fairseq.__version__) < version.parse("1.0.0a"):
    raise Exception("requires fairseq >= 1.0.0a")

logging.set_verbosity_info()
logger = logging.get_logger(__name__)

SAMPLE_TEXT = "Hello world! cécé herlolip"


def convert_xlm_roberta_xl_checkpoint_to_pytorch(
    roberta_checkpoint_path: str, pytorch_dump_folder_path: str, classification_head: bool
):
"""
Copy/paste/tweak roberta's weights to our BERT structure.
"""
roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path)
roberta.eval() # disable dropout
roberta_sent_encoder = roberta.model.encoder.sentence_encoder
config = XLMRobertaXLConfig(
vocab_size=roberta_sent_encoder.embed_tokens.num_embeddings,
hidden_size=roberta.cfg.model.encoder_embed_dim,
num_hidden_layers=roberta.cfg.model.encoder_layers,
num_attention_heads=roberta.cfg.model.encoder_attention_heads,
intermediate_size=roberta.cfg.model.encoder_ffn_embed_dim,
max_position_embeddings=514,
type_vocab_size=1,
layer_norm_eps=1e-5, # PyTorch default used in fairseq
)
if classification_head:
config.num_labels = roberta.model.classification_heads["mnli"].out_proj.weight.shape[0]

print("Our RoBERTa config:", config)

model = XLMRobertaXLForSequenceClassification(config) if classification_head else XLMRobertaXLForMaskedLM(config)
model.eval()

    # Now let's copy all the weights.
    # Embeddings
    model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight
    model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight
    model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(
        model.roberta.embeddings.token_type_embeddings.weight
    )  # just zero them out b/c RoBERTa doesn't use them.

    model.roberta.encoder.LayerNorm.weight = roberta_sent_encoder.layer_norm.weight
    model.roberta.encoder.LayerNorm.bias = roberta_sent_encoder.layer_norm.bias

    for i in range(config.num_hidden_layers):
        # Encoder: start of layer
        layer: BertLayer = model.roberta.encoder.layer[i]
        roberta_layer: TransformerSentenceEncoderLayer = roberta_sent_encoder.layers[i]

        attention: RobertaAttention = layer.attention
        attention.self_attn_layer_norm.weight = roberta_layer.self_attn_layer_norm.weight
        attention.self_attn_layer_norm.bias = roberta_layer.self_attn_layer_norm.bias

        # self attention
        self_attn: BertSelfAttention = layer.attention.self
        assert (
            roberta_layer.self_attn.k_proj.weight.data.shape
            == roberta_layer.self_attn.q_proj.weight.data.shape
            == roberta_layer.self_attn.v_proj.weight.data.shape
            == torch.Size((config.hidden_size, config.hidden_size))
        )

        self_attn.query.weight.data = roberta_layer.self_attn.q_proj.weight
        self_attn.query.bias.data = roberta_layer.self_attn.q_proj.bias
        self_attn.key.weight.data = roberta_layer.self_attn.k_proj.weight
        self_attn.key.bias.data = roberta_layer.self_attn.k_proj.bias
        self_attn.value.weight.data = roberta_layer.self_attn.v_proj.weight
        self_attn.value.bias.data = roberta_layer.self_attn.v_proj.bias

        # self-attention output
        self_output: BertSelfOutput = layer.attention.output
        assert self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
        self_output.dense.weight = roberta_layer.self_attn.out_proj.weight
        self_output.dense.bias = roberta_layer.self_attn.out_proj.bias

        # this one is final layer norm
        layer.LayerNorm.weight = roberta_layer.final_layer_norm.weight
        layer.LayerNorm.bias = roberta_layer.final_layer_norm.bias

        # intermediate
        intermediate: BertIntermediate = layer.intermediate
        assert intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
        intermediate.dense.weight = roberta_layer.fc1.weight
        intermediate.dense.bias = roberta_layer.fc1.bias

        # output
        bert_output: BertOutput = layer.output
        assert bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
        bert_output.dense.weight = roberta_layer.fc2.weight
        bert_output.dense.bias = roberta_layer.fc2.bias
        # end of layer

    if classification_head:
        model.classifier.dense.weight = roberta.model.classification_heads["mnli"].dense.weight
        model.classifier.dense.bias = roberta.model.classification_heads["mnli"].dense.bias
        model.classifier.out_proj.weight = roberta.model.classification_heads["mnli"].out_proj.weight
        model.classifier.out_proj.bias = roberta.model.classification_heads["mnli"].out_proj.bias
    else:
        # LM Head
        model.lm_head.dense.weight = roberta.model.encoder.lm_head.dense.weight
        model.lm_head.dense.bias = roberta.model.encoder.lm_head.dense.bias
        model.lm_head.layer_norm.weight = roberta.model.encoder.lm_head.layer_norm.weight
        model.lm_head.layer_norm.bias = roberta.model.encoder.lm_head.layer_norm.bias
        model.lm_head.decoder.weight = roberta.model.encoder.lm_head.weight
        model.lm_head.decoder.bias = roberta.model.encoder.lm_head.bias

    # Let's check that we get the same results.
    input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0)  # batch of size 1

    our_output = model(input_ids)[0]
    if classification_head:
        their_output = roberta.model.classification_heads["mnli"](roberta.extract_features(input_ids))
    else:
        their_output = roberta.model(input_ids)[0]
    print(our_output.shape, their_output.shape)
    max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
    print(f"max_absolute_diff = {max_absolute_diff}")  # ~ 1e-7
    success = torch.allclose(our_output, their_output, atol=1e-3)
    print("Do both models output the same tensors?", "🔥" if success else "💩")
    if not success:
        raise Exception("Something went wRoNg")

    pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--roberta_checkpoint_path", default=None, type=str, required=True, help="Path to the official PyTorch dump."
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--classification_head", action="store_true", help="Whether to convert a final classification head."
    )
    args = parser.parse_args()
    convert_xlm_roberta_xl_checkpoint_to_pytorch(
        args.roberta_checkpoint_path, args.pytorch_dump_folder_path, args.classification_head
    )