PaddlePaddle Hackathon 54 Submission #1086

Merged
merged 21 commits on Oct 25, 2021
95 changes: 95 additions & 0 deletions community/junnyu/electra_compare.py
@@ -0,0 +1,95 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddlenlp.transformers import ElectraDiscriminator, ElectraForMaskedLM, ElectraTokenizer
from transformers.models.electra.modeling_electra import ElectraForPreTraining, ElectraForMaskedLM as PTElectraForMaskedLM
import paddle
import torch
import numpy as np


def compare(a, b):
    """Print the mean and max absolute difference between two tensors."""
    a = a.cpu().numpy()
    b = b.cpu().numpy()
    meandif = np.abs(a - b).mean()
    maxdif = np.abs(a - b).max()
    print("mean dif:", meandif)
    print("max dif:", maxdif)


def compare_discriminator(
        path="MODEL/hfl-chinese-electra-180g-base-discriminator"):
    # Load the same checkpoint into the PaddleNLP and Hugging Face implementations.
    pdmodel = ElectraDiscriminator.from_pretrained(path)
    ptmodel = ElectraForPreTraining.from_pretrained(path).cuda()
    tokenizer = ElectraTokenizer.from_pretrained(path)
    pdmodel.eval()
    ptmodel.eval()
    text = "欢迎使用paddlenlp!"
    pdinputs = {
        k: paddle.to_tensor(v, dtype="int64").unsqueeze(0)
        for k, v in tokenizer(text).items()
    }
    ptinputs = {
        k: torch.tensor(v, dtype=torch.long).unsqueeze(0).cuda()
        for k, v in tokenizer(text).items()
    }
    with paddle.no_grad():
        pd_logits = pdmodel(**pdinputs)

    with torch.no_grad():
        pt_logits = ptmodel(**ptinputs).logits

    compare(pd_logits, pt_logits)


def compare_generator(path="MODEL/hfl-chinese-legal-electra-small-generator"):
    pdmodel = ElectraForMaskedLM.from_pretrained(path)
    ptmodel = PTElectraForMaskedLM.from_pretrained(path).cuda()
    tokenizer = ElectraTokenizer.from_pretrained(path)
    pdmodel.eval()
    ptmodel.eval()
    text = "欢迎使用paddlenlp!"
    pdinputs = {
        k: paddle.to_tensor(v, dtype="int64").unsqueeze(0)
        for k, v in tokenizer(text).items()
    }
    ptinputs = {
        k: torch.tensor(v, dtype=torch.long).unsqueeze(0).cuda()
        for k, v in tokenizer(text).items()
    }
    with paddle.no_grad():
        pd_prediction_scores = pdmodel(**pdinputs)

    with torch.no_grad():
        pt_logits = ptmodel(**ptinputs).logits

    compare(pd_prediction_scores, pt_logits)


if __name__ == "__main__":
    compare_discriminator(
        path="MODEL/hfl-chinese-electra-180g-base-discriminator")
    # mean dif: 3.1698835e-06
    # max dif: 1.335144e-05
    compare_discriminator(
        path="MODEL/hfl-chinese-electra-180g-small-ex-discriminator")
    # mean dif: 3.7930229e-06
    # max dif: 1.04904175e-05
    compare_generator(path="MODEL/hfl-chinese-legal-electra-small-generator")
    # mean dif: 6.6151397e-06
    # max dif: 9.346008e-05
79 changes: 79 additions & 0 deletions community/junnyu/electra_convert_huggingface2paddle.py
@@ -0,0 +1,79 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
import argparse

huggingface_to_paddle = {
    "embeddings.LayerNorm": "embeddings.layer_norm",
    "encoder.layer": "encoder.layers",
    "attention.self.query.": "self_attn.q_proj.",
    "attention.self.key.": "self_attn.k_proj.",
    "attention.self.value.": "self_attn.v_proj.",
    "attention.output.dense.": "self_attn.out_proj.",
    "intermediate.dense": "linear1",
    "output.dense": "linear2",
    "attention.output.LayerNorm": "norm1",
    "output.LayerNorm": "norm2",
    "generator_predictions.LayerNorm": "generator_predictions.layer_norm",
    "generator_lm_head.bias": "generator_lm_head_bias",
}

skip_weights = ["electra.embeddings.position_ids"]
dont_transpose = ["_embeddings.weight", "LayerNorm."]


def convert_pytorch_checkpoint_to_paddle(pytorch_checkpoint_path,
                                         paddle_dump_path):
    import torch
    import paddle
    pytorch_state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
    paddle_state_dict = OrderedDict()
    for k, v in pytorch_state_dict.items():
        # generator_lm_head.weight is tied to the word embeddings, so only its bias is kept.
        if k == "generator_lm_head.weight":
            continue
        is_transpose = False
        if k in skip_weights:
            continue
        # PyTorch stores nn.Linear weights as [out, in] while Paddle uses [in, out],
        # so 2-D weights are transposed unless they match a dont_transpose pattern.
        if k[-7:] == ".weight":
            if not any(w in k for w in dont_transpose):
                if v.ndim == 2:
                    v = v.transpose(0, 1)
                    is_transpose = True
        oldk = k
        for huggingface_name, paddle_name in huggingface_to_paddle.items():
            k = k.replace(huggingface_name, paddle_name)

        print(f"Converting: {oldk} => {k} | is_transpose {is_transpose}")
        paddle_state_dict[k] = v.data.numpy()

    paddle.save(paddle_state_dict, paddle_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pytorch_checkpoint_path",
        default=r"MODEL\hfl-chinese-electra-180g-base-discriminator\pytorch_model.bin",
        type=str,
        required=False,
        help="Path to the PyTorch checkpoint.")
    parser.add_argument(
        "--paddle_dump_path",
        default=r"MODEL\hfl-chinese-electra-180g-base-discriminator\model_state.pdparams",
        type=str,
        required=False,
        help="Path to the output Paddle model.")
    args = parser.parse_args()
    convert_pytorch_checkpoint_to_paddle(args.pytorch_checkpoint_path,
                                         args.paddle_dump_path)
@@ -0,0 +1,33 @@
# Detailed Introduction
# Chinese ELECTRA
Google and Stanford University released a new pre-trained model called ELECTRA, which has a much more compact model size and relatively competitive performance compared with BERT and its variants. To further accelerate research on Chinese pre-trained models, the Joint Laboratory of HIT and iFLYTEK Research (HFL) released the Chinese ELECTRA models based on the official ELECTRA code. Compared with BERT and its variants, ELECTRA-small can reach similar or even higher scores on several NLP tasks with only 1/10 of the parameters.
This project is based on the official ELECTRA code: https://github.com/google-research/electra
This model is the base-size discriminator and was trained on 180 GB of Chinese data.

# Usage Example

```python
import paddle
from paddlenlp.transformers import ElectraDiscriminator, ElectraTokenizer

path = "hfl-chinese-electra-180g-base-discriminator"
model = ElectraDiscriminator.from_pretrained(path)
tokenizer = ElectraTokenizer.from_pretrained(path)
model.eval()

text = "欢迎使用paddlenlp!"
inputs = {
    k: paddle.to_tensor(v, dtype="int64").unsqueeze(0)
    for k, v in tokenizer(text).items()
}

with paddle.no_grad():
    logits = model(**inputs)

print(logits.shape)

```
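
The `logits` returned by `ElectraDiscriminator` are scores from ELECTRA's replaced-token-detection head, one score per input token. As a small follow-on sketch (hedged: it simply reuses `model`, `tokenizer`, and `inputs` from the example above and assumes the usual PaddleNLP tokenizer helpers), a sigmoid turns them into per-token "was this token replaced?" probabilities:

```python
import paddle.nn.functional as F

# Flatten in case a leading batch dimension is present.
probs = F.sigmoid(logits).flatten().numpy().tolist()
tokens = tokenizer.convert_ids_to_tokens(
    inputs["input_ids"].flatten().numpy().tolist())
for token, p in zip(tokens, probs):
    print(token, round(p, 4))  # higher = more likely replaced
```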

# Weight Source

https://huggingface.co/hfl/chinese-electra-180g-base-discriminator
@@ -0,0 +1,6 @@
{
"model_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-base-discriminator/model_config.json",
"model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-base-discriminator/model_state.pdparams",
"tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-base-discriminator/tokenizer_config.json",
"vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-base-discriminator/vocab.txt"
}
@@ -0,0 +1,33 @@
# Detailed Introduction
# Chinese ELECTRA
Google and Stanford University released a new pre-trained model called ELECTRA, which has a much more compact model size and relatively competitive performance compared with BERT and its variants. To further accelerate research on Chinese pre-trained models, the Joint Laboratory of HIT and iFLYTEK Research (HFL) released the Chinese ELECTRA models based on the official ELECTRA code. Compared with BERT and its variants, ELECTRA-small can reach similar or even higher scores on several NLP tasks with only 1/10 of the parameters.
This project is based on the official ELECTRA code: https://github.com/google-research/electra
This model is the small version of the discriminator and was trained on 180 GB of Chinese data.

# Usage Example

```python
import paddle
from paddlenlp.transformers import ElectraDiscriminator, ElectraTokenizer

path = "hfl-chinese-electra-180g-small-ex-discriminator"
model = ElectraDiscriminator.from_pretrained(path)
tokenizer = ElectraTokenizer.from_pretrained(path)
model.eval()

text = "欢迎使用paddlenlp!"
inputs = {
    k: paddle.to_tensor(v, dtype="int64").unsqueeze(0)
    for k, v in tokenizer(text).items()
}

with paddle.no_grad():
    logits = model(**inputs)

print(logits.shape)

```
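
As with the base model, `logits` carries one replaced-token score per input token. A hedged follow-up sketch (reusing `model`, `tokenizer`, and `inputs` from the example above) thresholds the sigmoid probabilities at 0.5 to list the tokens the discriminator considers replaced:

```python
import paddle.nn.functional as F

tokens = tokenizer.convert_ids_to_tokens(
    inputs["input_ids"].flatten().numpy().tolist())
probs = F.sigmoid(logits).flatten().numpy().tolist()
print([t for t, p in zip(tokens, probs) if p > 0.5])  # tokens flagged as replaced
```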

# Weight Source

https://huggingface.co/hfl/chinese-electra-180g-small-ex-discriminator
@@ -0,0 +1,6 @@
{
"model_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/model_config.json",
"model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/model_state.pdparams",
"tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/tokenizer_config.json",
"vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/vocab.txt"
}
@@ -0,0 +1,33 @@
# Detailed Introduction
# Chinese ELECTRA
Google and Stanford University released a new pre-trained model called ELECTRA, which has a much more compact model size and relatively competitive performance compared with BERT and its variants. To further accelerate research on Chinese pre-trained models, the Joint Laboratory of HIT and iFLYTEK Research (HFL) released the Chinese ELECTRA models based on the official ELECTRA code. Compared with BERT and its variants, ELECTRA-small can reach similar or even higher scores on several NLP tasks with only 1/10 of the parameters.
This project is based on the official ELECTRA code: https://github.com/google-research/electra
This model is the small version of the generator and is designed specifically for the legal domain.

# Usage Example

```python
import paddle
from paddlenlp.transformers import ElectraGenerator, ElectraTokenizer

path = "hfl-chinese-legal-electra-small-generator"
model = ElectraGenerator.from_pretrained(path)
tokenizer = ElectraTokenizer.from_pretrained(path)
model.eval()

text = "欢迎使用paddlenlp!"
inputs = {
    k: paddle.to_tensor(v, dtype="int64").unsqueeze(0)
    for k, v in tokenizer(text).items()
}

with paddle.no_grad():
    prediction_scores = model(**inputs)

print(prediction_scores.shape)

```
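
Since the generator is a small masked language model, `prediction_scores` holds vocabulary logits for every position. Below is a minimal sketch of filling in a masked character (hedged: it reuses `model` and `tokenizer` from the example above; the masked position index and the `[MASK]` token lookup are illustrative assumptions based on the BERT-style vocabulary):

```python
import paddle

text = "欢迎使用paddlenlp!"
encoded = tokenizer(text)
mask_pos = 2  # position to hide, chosen only for illustration
encoded["input_ids"][mask_pos] = tokenizer.convert_tokens_to_ids("[MASK]")

inputs = {
    k: paddle.to_tensor(v, dtype="int64").unsqueeze(0) for k, v in encoded.items()
}
with paddle.no_grad():
    prediction_scores = model(**inputs)

# Highest-scoring vocabulary id at the masked position.
pred_id = int(paddle.argmax(prediction_scores[0, mask_pos]).numpy())
print(tokenizer.convert_ids_to_tokens([pred_id]))
```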

# Weight Source

https://huggingface.co/hfl/chinese-legal-electra-small-generator
@@ -0,0 +1,6 @@
{
"model_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-legal-electra-small-generator/model_config.json",
"model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-legal-electra-small-generator/model_state.pdparams",
"tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-legal-electra-small-generator/tokenizer_config.json",
"vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-legal-electra-small-generator/vocab.txt"
}
1 change: 0 additions & 1 deletion paddlenlp/transformers/electra/README.md

This file was deleted.
