From 4e52188433efd3977a15a6823fcd14d150e6b3a7 Mon Sep 17 00:00:00 2001
From: lukovnikov
Date: Tue, 6 Nov 2018 17:47:03 +0100
Subject: [PATCH 1/7] bert weight loading from tf

---
 convert_tf_checkpoint_to_pytorch.py | 56 ++++++++++++-----------
 modeling.py                         |  2 +-
 tests/mytest.py                     | 71 +++++++++++++++++++++++++++++
 3 files changed, 102 insertions(+), 27 deletions(-)
 create mode 100644 tests/mytest.py

diff --git a/convert_tf_checkpoint_to_pytorch.py b/convert_tf_checkpoint_to_pytorch.py
index dfcdbee42d5091..408951a9364d68 100644
--- a/convert_tf_checkpoint_to_pytorch.py
+++ b/convert_tf_checkpoint_to_pytorch.py
@@ -26,35 +26,14 @@
 from modeling import BertConfig, BertModel
 
-parser = argparse.ArgumentParser()
 
-## Required parameters
-parser.add_argument("--tf_checkpoint_path",
-                    default = None,
-                    type = str,
-                    required = True,
-                    help = "Path the TensorFlow checkpoint path.")
-parser.add_argument("--bert_config_file",
-                    default = None,
-                    type = str,
-                    required = True,
-                    help = "The config json file corresponding to the pre-trained BERT model. \n"
-                        "This specifies the model architecture.")
-parser.add_argument("--pytorch_dump_path",
-                    default = None,
-                    type = str,
-                    required = True,
-                    help = "Path to the output PyTorch model.")
-
-args = parser.parse_args()
-
-def convert():
+def convert(config_path, ckpt_path, out_path=None):
     # Initialise PyTorch model
-    config = BertConfig.from_json_file(args.bert_config_file)
+    config = BertConfig.from_json_file(config_path)
     model = BertModel(config)
 
     # Load weights from TF model
-    path = args.tf_checkpoint_path
+    path = ckpt_path
     print("Converting TensorFlow checkpoint from {}".format(path))
 
     init_vars = tf.train.list_variables(path)
@@ -99,7 +78,32 @@ def convert():
         pointer.data = torch.from_numpy(array)
 
     # Save pytorch-model
-    torch.save(model.state_dict(), args.pytorch_dump_path)
+    if out_path is not None:
+        torch.save(model.state_dict(), out_path)
+    return model
+
 
 if __name__ == "__main__":
-    convert()
+    parser = argparse.ArgumentParser()
+
+    ## Required parameters
+    parser.add_argument("--tf_checkpoint_path",
+                        default=None,
+                        type=str,
+                        required=True,
+                        help="Path the TensorFlow checkpoint path.")
+    parser.add_argument("--bert_config_file",
+                        default=None,
+                        type=str,
+                        required=True,
+                        help="The config json file corresponding to the pre-trained BERT model. \n"
+                             "This specifies the model architecture.")
+    parser.add_argument("--pytorch_dump_path",
+                        default=None,
+                        type=str,
+                        required=False,
+                        help="Path to the output PyTorch model.")
+
+    args = parser.parse_args()
+    print(args)
+    convert(args.bert_config_file, args.tf_checkpoint_path, args.pytorch_dump_path)
diff --git a/modeling.py b/modeling.py
index c467e8266efa82..4cbb99f2fab6dc 100644
--- a/modeling.py
+++ b/modeling.py
@@ -355,7 +355,7 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None):
         all_encoder_layers = self.encoder(embedding_output, extended_attention_mask)
         sequence_output = all_encoder_layers[-1]
         pooled_output = self.pooler(sequence_output)
-        return all_encoder_layers, pooled_output
+        return [embedding_output] + all_encoder_layers, pooled_output
 
 class BertForSequenceClassification(nn.Module):
     """BERT model for classification.
diff --git a/tests/mytest.py b/tests/mytest.py
new file mode 100644
index 00000000000000..2b2dadecda9545
--- /dev/null
+++ b/tests/mytest.py
@@ -0,0 +1,71 @@
+import unittest
+import json
+import random
+
+import torch
+import numpy as np
+
+import modeling
+import convert_tf_checkpoint_to_pytorch
+
+import grouch
+
+
+class MyTest(unittest.TestCase):
+    def test_loading_and_running(self):
+        bertpath = "../../grouch/data/bert/bert-base/"
+        configpath = bertpath + "bert_config.json"
+        ckptpath = bertpath + "bert_model.ckpt"
+        m = convert_tf_checkpoint_to_pytorch.convert(configpath, ckptpath)
+        m.eval()
+        # print(m)
+
+        input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+        input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+        token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+
+        all_y, pool_y = m(input_ids, token_type_ids, input_mask)
+        print(pool_y.shape)
+        # np.save("_bert_ref_pool_out.npy", pool_y.detach().numpy())
+        # np.save("_bert_ref_all_out.npy", torch.stack(all_y, 0).detach().numpy())
+
+        config = grouch.TransformerBERT.load_config(configpath)
+        gm = grouch.TransformerBERT.init_from_config(config)
+        gm.load_weights_from_tf_checkpoint(ckptpath)
+        gm.eval()
+
+        g_all_y, g_pool_y = gm(input_ids, token_type_ids, input_mask)
+        print(g_pool_y.shape)
+
+        # check embeddings
+        # print(m.embeddings)
+        # print(gm.emb)
+        # hugging_emb = m.embeddings(input_ids, token_type_ids)
+        # grouch_emb = gm.emb(input_ids, token_type_ids)
+
+        print((all_y[0] - g_all_y[0]).norm())
+        # print(all_y[0][:, :, :10] - g_all_y[0][:, :, :10])
+        self.assertTrue(np.allclose(all_y[0].detach().numpy(), g_all_y[0].detach().numpy(), atol=1e-7))
+        print("embeddings good")
+
+        print(m.encoder.layer[0])
+        print(gm.encoder.layers[0])
+        print("norm of diff at layer 1", (all_y[1] - g_all_y[1]).norm())
+        # print(all_y[1][:, :, :10] - g_all_y[1][:, :, :10])
+        self.assertTrue(np.allclose(all_y[1].detach().numpy(), g_all_y[1].detach().numpy(), atol=1e-6))
+
+        # hugging_layer = m.encoder.layer[0]
+        # grouch_layer = gm.encoder.layers[0]
+        # print("comparing weights")
+        # print((hugging_layer.attention.self.query.weight - grouch_layer.slf_attn.q_proj.weight).norm())
+        # print((hugging_layer.attention.self.query.bias - grouch_layer.slf_attn.q_proj.bias).norm())
+        # print((hugging_layer.attention.self.key.weight - grouch_layer.slf_attn.k_proj.weight).norm())
+        # print((hugging_layer.attention.self.key.bias - grouch_layer.slf_attn.k_proj.bias).norm())
+        # print((hugging_layer.attention.self.value.weight - grouch_layer.slf_attn.v_proj.weight).norm())
+        # print((hugging_layer.attention.self.value.bias - grouch_layer.slf_attn.v_proj.bias).norm())
+        # print((hugging_layer.attention.output.dense.weight - grouch_layer.slf_attn.vw_proj.weight).norm())
+        # print((hugging_layer.attention.output.dense.bias - grouch_layer.slf_attn.vw_proj.bias).norm())
+
+        print("norm of diff at last layer", (all_y[-1] - g_all_y[-1]).norm())
+        # print(all_y[-1][:, :, :10] - g_all_y[-1][:, :, :10])
+        self.assertTrue(np.allclose(all_y[-1].detach().numpy(), g_all_y[-1].detach().numpy(), atol=1e-4))
\ No newline at end of file

From bd91ae654faf8bc45eb68b13668d2013df4ffb9c Mon Sep 17 00:00:00 2001
From: lukovnikov
Date: Tue, 6 Nov 2018 18:21:44 +0100
Subject: [PATCH 2/7] moved bert to qelos-util

---
 hf_bert/__init__.py |  0
 modeling.py         | 11 +++++--
 tests/mytest.py     | 71 ---------------------------------------------
 3 files changed, 8 insertions(+), 74 deletions(-)
 create mode 100644 hf_bert/__init__.py
 delete mode 100644 tests/mytest.py

diff --git a/hf_bert/__init__.py b/hf_bert/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/modeling.py b/modeling.py
index 4cbb99f2fab6dc..dd43c9c46aaea2 100644
--- a/modeling.py
+++ b/modeling.py
@@ -34,6 +34,10 @@ def gelu(x):
     return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
 
 
+def swish(x):
+    return x * torch.sigmoid(x)
+
+
 class BertConfig(object):
     """Configuration class to store the configuration of a `BertModel`.
     """
@@ -60,7 +64,7 @@ def __init__(self,
             intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                 layer in the Transformer encoder.
             hidden_act: The non-linear activation function (function or string) in the
-                encoder and pooler.
+                encoder and pooler. If string, "gelu", "relu" and "swish" supported.
             hidden_dropout_prob: The dropout probabilitiy for all fully connected
                 layers in the embeddings, encoder, and pooler.
             attention_probs_dropout_prob: The dropout ratio for the attention
@@ -237,7 +241,8 @@ class BERTIntermediate(nn.Module):
     def __init__(self, config):
         super(BERTIntermediate, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
-        self.intermediate_act_fn = gelu
+        act2fn = {"gelu": gelu, "relu": torch.nn.ReLU, "swish": swish}
+        self.intermediate_act_fn = act2fn[config.hidden_act] if isinstance(config.hidden_act, str) else config.hidden_act
 
     def forward(self, hidden_states):
         hidden_states = self.dense(hidden_states)
@@ -355,7 +360,7 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None):
         all_encoder_layers = self.encoder(embedding_output, extended_attention_mask)
         sequence_output = all_encoder_layers[-1]
         pooled_output = self.pooler(sequence_output)
-        return [embedding_output] + all_encoder_layers, pooled_output
+        return all_encoder_layers, pooled_output
 
 class BertForSequenceClassification(nn.Module):
     """BERT model for classification.
diff --git a/tests/mytest.py b/tests/mytest.py
deleted file mode 100644
index 2b2dadecda9545..00000000000000
--- a/tests/mytest.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import unittest
-import json
-import random
-
-import torch
-import numpy as np
-
-import modeling
-import convert_tf_checkpoint_to_pytorch
-
-import grouch
-
-
-class MyTest(unittest.TestCase):
-    def test_loading_and_running(self):
-        bertpath = "../../grouch/data/bert/bert-base/"
-        configpath = bertpath + "bert_config.json"
-        ckptpath = bertpath + "bert_model.ckpt"
-        m = convert_tf_checkpoint_to_pytorch.convert(configpath, ckptpath)
-        m.eval()
-        # print(m)
-
-        input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
-        input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
-        token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
-
-        all_y, pool_y = m(input_ids, token_type_ids, input_mask)
-        print(pool_y.shape)
-        # np.save("_bert_ref_pool_out.npy", pool_y.detach().numpy())
-        # np.save("_bert_ref_all_out.npy", torch.stack(all_y, 0).detach().numpy())
-
-        config = grouch.TransformerBERT.load_config(configpath)
-        gm = grouch.TransformerBERT.init_from_config(config)
-        gm.load_weights_from_tf_checkpoint(ckptpath)
-        gm.eval()
-
-        g_all_y, g_pool_y = gm(input_ids, token_type_ids, input_mask)
-        print(g_pool_y.shape)
-
-        # check embeddings
-        # print(m.embeddings)
-        # print(gm.emb)
-        # hugging_emb = m.embeddings(input_ids, token_type_ids)
-        # grouch_emb = gm.emb(input_ids, token_type_ids)
-
-        print((all_y[0] - g_all_y[0]).norm())
-        # print(all_y[0][:, :, :10] - g_all_y[0][:, :, :10])
-        self.assertTrue(np.allclose(all_y[0].detach().numpy(), g_all_y[0].detach().numpy(), atol=1e-7))
-        print("embeddings good")
-
-        print(m.encoder.layer[0])
-        print(gm.encoder.layers[0])
-        print("norm of diff at layer 1", (all_y[1] - g_all_y[1]).norm())
-        # print(all_y[1][:, :, :10] - g_all_y[1][:, :, :10])
-        self.assertTrue(np.allclose(all_y[1].detach().numpy(), g_all_y[1].detach().numpy(), atol=1e-6))
-
-        # hugging_layer = m.encoder.layer[0]
-        # grouch_layer = gm.encoder.layers[0]
-        # print("comparing weights")
-        # print((hugging_layer.attention.self.query.weight - grouch_layer.slf_attn.q_proj.weight).norm())
-        # print((hugging_layer.attention.self.query.bias - grouch_layer.slf_attn.q_proj.bias).norm())
-        # print((hugging_layer.attention.self.key.weight - grouch_layer.slf_attn.k_proj.weight).norm())
-        # print((hugging_layer.attention.self.key.bias - grouch_layer.slf_attn.k_proj.bias).norm())
-        # print((hugging_layer.attention.self.value.weight - grouch_layer.slf_attn.v_proj.weight).norm())
-        # print((hugging_layer.attention.self.value.bias - grouch_layer.slf_attn.v_proj.bias).norm())
-        # print((hugging_layer.attention.output.dense.weight - grouch_layer.slf_attn.vw_proj.weight).norm())
-        # print((hugging_layer.attention.output.dense.bias - grouch_layer.slf_attn.vw_proj.bias).norm())
-
-        print("norm of diff at last layer", (all_y[-1] - g_all_y[-1]).norm())
-        # print(all_y[-1][:, :, :10] - g_all_y[-1][:, :, :10])
-        self.assertTrue(np.allclose(all_y[-1].detach().numpy(), g_all_y[-1].detach().numpy(), atol=1e-4))
\ No newline at end of file

From fa0c5a2ea1da8ce9049a9e6f12a712b7c58a7119 Mon Sep 17 00:00:00 2001
From: lukovnikov
Date: Tue, 13 Nov 2018 16:24:53 +0100
Subject: [PATCH 3/7] clean up pr

---
 convert_tf_checkpoint_to_pytorch.py | 68 ++++++++++++-----------------
 1 file changed, 29 insertions(+), 39 deletions(-)

diff --git a/convert_tf_checkpoint_to_pytorch.py b/convert_tf_checkpoint_to_pytorch.py
index d4d47a3bd6372e..dfcdbee42d5091 100755
--- a/convert_tf_checkpoint_to_pytorch.py
+++ b/convert_tf_checkpoint_to_pytorch.py
@@ -26,14 +26,35 @@
 from modeling import BertConfig, BertModel
 
+parser = argparse.ArgumentParser()
 
-def convert(config_path, ckpt_path, out_path=None):
+## Required parameters
+parser.add_argument("--tf_checkpoint_path",
+                    default = None,
+                    type = str,
+                    required = True,
+                    help = "Path the TensorFlow checkpoint path.")
+parser.add_argument("--bert_config_file",
+                    default = None,
+                    type = str,
+                    required = True,
+                    help = "The config json file corresponding to the pre-trained BERT model. \n"
+                        "This specifies the model architecture.")
+parser.add_argument("--pytorch_dump_path",
+                    default = None,
+                    type = str,
+                    required = True,
+                    help = "Path to the output PyTorch model.")
+
+args = parser.parse_args()
+
+def convert():
     # Initialise PyTorch model
-    config = BertConfig.from_json_file(config_path)
+    config = BertConfig.from_json_file(args.bert_config_file)
     model = BertModel(config)
 
     # Load weights from TF model
-    path = ckpt_path
+    path = args.tf_checkpoint_path
     print("Converting TensorFlow checkpoint from {}".format(path))
 
     init_vars = tf.train.list_variables(path)
@@ -47,17 +68,11 @@ def convert(config_path, ckpt_path, out_path=None):
         arrays.append(array)
 
     for name, array in zip(names, arrays):
-        if not name.startswith("bert"):
-            print("Skipping {}".format(name))
-            continue
-        else:
-            name = name.replace("bert/", "") # skip "bert/"
+        name = name[5:] # skip "bert/"
         print("Loading {}".format(name))
         name = name.split('/')
-        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
-        # which are not required for using pretrained model
-        if name[0] in ['redictions', 'eq_relationship'] or name[-1] == "adam_v" or name[-1] == "adam_m":
-            print("Skipping {}".format("/".join(name)))
+        if name[0] in ['redictions', 'eq_relationship']:
+            print("Skipping")
             continue
         pointer = model
         for m_name in name:
@@ -84,32 +99,7 @@ def convert(config_path, ckpt_path, out_path=None):
         pointer.data = torch.from_numpy(array)
 
     # Save pytorch-model
-    if out_path is not None:
-        torch.save(model.state_dict(), out_path)
-    return model
-
+    torch.save(model.state_dict(), args.pytorch_dump_path)
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-
-    ## Required parameters
-    parser.add_argument("--tf_checkpoint_path",
-                        default=None,
-                        type=str,
-                        required=True,
-                        help="Path the TensorFlow checkpoint path.")
-    parser.add_argument("--bert_config_file",
-                        default=None,
-                        type=str,
-                        required=True,
-                        help="The config json file corresponding to the pre-trained BERT model. \n"
-                             "This specifies the model architecture.")
-    parser.add_argument("--pytorch_dump_path",
-                        default=None,
-                        type=str,
-                        required=False,
-                        help="Path to the output PyTorch model.")
-
-    args = parser.parse_args()
-    print(args)
-    convert(args.bert_config_file, args.tf_checkpoint_path, args.pytorch_dump_path)
+    convert()

From 7ba83730c48f80e932450da86ec601131d0f3679 Mon Sep 17 00:00:00 2001
From: lukovnikov
Date: Tue, 13 Nov 2018 16:31:20 +0100
Subject: [PATCH 4/7] clean up pr

---
 convert_tf_checkpoint_to_pytorch.py | 12 +++++++++---
 modeling.py                         |  8 ++++++--
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/convert_tf_checkpoint_to_pytorch.py b/convert_tf_checkpoint_to_pytorch.py
index dfcdbee42d5091..eeebb3728ee8a6 100755
--- a/convert_tf_checkpoint_to_pytorch.py
+++ b/convert_tf_checkpoint_to_pytorch.py
@@ -68,11 +68,17 @@
         arrays.append(array)
 
     for name, array in zip(names, arrays):
-        name = name[5:] # skip "bert/"
+        if not name.startswith("bert"):
+            print("Skipping {}".format(name))
+            continue
+        else:
+            name = name.replace("bert/", "") # skip "bert/"
         print("Loading {}".format(name))
         name = name.split('/')
-        if name[0] in ['redictions', 'eq_relationship']:
-            print("Skipping")
+        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
+        # which are not required for using pretrained model
+        if name[0] in ['redictions', 'eq_relationship'] or name[-1] == "adam_v" or name[-1] == "adam_m":
+            print("Skipping {}".format("/".join(name)))
             continue
         pointer = model
         for m_name in name:
diff --git a/modeling.py b/modeling.py
index 9874b3d5dfee2a..3b3f198c922de1 100644
--- a/modeling.py
+++ b/modeling.py
@@ -26,6 +26,10 @@
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
 
+
+ACT2FN = {"gelu": gelu, "relu": torch.nn.ReLU, "swish": swish}
+
+
 def gelu(x):
     """Implementation of the gelu activation function.
         For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
@@ -241,8 +245,8 @@ class BERTIntermediate(nn.Module):
     def __init__(self, config):
         super(BERTIntermediate, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
-        act2fn = {"gelu": gelu, "relu": torch.nn.ReLU, "swish": swish}
-        self.intermediate_act_fn = act2fn[config.hidden_act] if isinstance(config.hidden_act, str) else config.hidden_act
+        self.intermediate_act_fn = ACT2FN[config.hidden_act] \
+            if isinstance(config.hidden_act, str) else config.hidden_act
 
     def forward(self, hidden_states):
         hidden_states = self.dense(hidden_states)

From d64db6dfb94b302da876c03c989acf34deaa4ed4 Mon Sep 17 00:00:00 2001
From: lukovnikov
Date: Tue, 13 Nov 2018 16:41:01 +0100
Subject: [PATCH 5/7] clean up pr

---
 modeling.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/modeling.py b/modeling.py
index 3b3f198c922de1..53243e5eb435e2 100644
--- a/modeling.py
+++ b/modeling.py
@@ -25,10 +25,7 @@
 import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
-
-
-ACT2FN = {"gelu": gelu, "relu": torch.nn.ReLU, "swish": swish}
-
+from six import string_types
 
 def gelu(x):
     """Implementation of the gelu activation function.
@@ -42,6 +39,9 @@ def swish(x):
     return x * torch.sigmoid(x)
 
 
+ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
+
+
 class BertConfig(object):
     """Configuration class to store the configuration of a `BertModel`.
     """
@@ -68,7 +68,7 @@ def __init__(self,
             intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                 layer in the Transformer encoder.
             hidden_act: The non-linear activation function (function or string) in the
-                encoder and pooler. If string, "gelu", "relu" and "swish" supported.
+                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
             hidden_dropout_prob: The dropout probabilitiy for all fully connected
                 layers in the embeddings, encoder, and pooler.
             attention_probs_dropout_prob: The dropout ratio for the attention
@@ -246,7 +246,7 @@ def __init__(self, config):
         super(BERTIntermediate, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
         self.intermediate_act_fn = ACT2FN[config.hidden_act] \
-            if isinstance(config.hidden_act, str) else config.hidden_act
+            if isinstance(config.hidden_act, string_types) else config.hidden_act
 
     def forward(self, hidden_states):
         hidden_states = self.dense(hidden_states)

From 3d4c7a6f5d8be320108c83d3a4a76105f6dd7071 Mon Sep 17 00:00:00 2001
From: Denis
Date: Tue, 13 Nov 2018 16:48:43 +0100
Subject: [PATCH 6/7] Delete __init__.py

---
 hf_bert/__init__.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 hf_bert/__init__.py

diff --git a/hf_bert/__init__.py b/hf_bert/__init__.py
deleted file mode 100644
index e69de29bb2d1d6..00000000000000

From 9f3cd27187dd107b5e31111b47ce50a5bcad074f Mon Sep 17 00:00:00 2001
From: lukovnikov
Date: Tue, 13 Nov 2018 16:48:59 +0100
Subject: [PATCH 7/7] clean up pr

---
 hf_bert/__init__.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 hf_bert/__init__.py

diff --git a/hf_bert/__init__.py b/hf_bert/__init__.py
deleted file mode 100644
index e69de29bb2d1d6..00000000000000