Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PaddlePaddle Hackathon 56 提交 #1088

Merged
merged 25 commits into from
Nov 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions community/junnyu/distilgpt2/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# 详细介绍
# DistilGPT2
**介绍**: DistilGPT2 英语语言模型使用 OpenWebTextCorpus(OpenAI 的 WebText 数据集),使用 GPT2 的最小版本的进行了预训练。 该模型有 6 层、768 个维度和 12 个头,总计 82M 参数(相比之下 GPT2 的参数为 124M)。 平均而言,DistilGPT2 比 GPT2 快两倍。
在 WikiText-103 基准测试中,GPT2 在测试集上的困惑度为 16.3,而 DistilGPT2 的困惑度为 21.1(在训练集上进行微调后)。

**模型结构**: **`GPTLMHeadModel`**,GPT模型。

**适用下游任务**:**文本生成**。

# 使用示例

```python

import numpy as np

import paddle
from paddlenlp.transformers import GPTTokenizer, GPTLMHeadModel
from paddlenlp.utils.log import logger


class Demo:
def __init__(self, model_name_or_path="junnyu/distilgpt2", max_predict_len=32):

self.tokenizer = GPTTokenizer.from_pretrained(model_name_or_path)
logger.info("Loading the model parameters, please wait...")
self.max_predict_len = max_predict_len
self.model = GPTLMHeadModel.from_pretrained(
model_name_or_path, eol_token_id=self.tokenizer.eol_token_id
)
self.model.eval()
logger.info("Model loaded.")

@paddle.no_grad()
def predict(self, text="My name is Teven and I am"):
ids = self.tokenizer(text)["input_ids"]
input_ids = paddle.to_tensor(np.array(ids).reshape(1, -1).astype("int64"))
out = self.model.generate(
input_ids,
max_length=self.max_predict_len,
repetition_penalty=1.2,
temperature=0,
)[0][0]
out = [int(x) for x in out.numpy().reshape([-1])]
print(text + self.tokenizer.convert_ids_to_string(out))

demo = Demo(model_name_or_path="junnyu/distilgpt2",max_predict_len=64)
demo.predict(text="My name is Teven and I am")

# My name is Teven and I am a member of the team.
# I have been playing with my friends since we were little, so it was nice to see them play together in our home town on Saturday night! We are very excited about this opportunity for us as well!!<|endoftext|>
```

# 权重来源

https://huggingface.co/distilgpt2
7 changes: 7 additions & 0 deletions community/junnyu/distilgpt2/files.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"model_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/distilgpt2/model_config.json",
"model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/distilgpt2/model_state.pdparams",
"tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/distilgpt2/tokenizer_config.json",
"merges_file":"https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/distilgpt2/merges.txt",
"vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/distilgpt2/vocab.json"
}
115 changes: 115 additions & 0 deletions community/junnyu/gpt_compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddlenlp.transformers import GPTLMHeadModel as PDGPT2LMHeadModel, GPTTokenizer, BertTokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as PTGPT2LMHeadModel
from paddlenlp.transformers import GPTForTokenClassification, GPTForSequenceClassification, GPTTokenizer
import paddle
import torch
import numpy as np

paddle.set_grad_enabled(False)
torch.set_grad_enabled(False)


def compare(a, b):
a = a.cpu().numpy()
b = b.cpu().numpy()
meandif = np.abs(a - b).mean()
maxdif = np.abs(a - b).max()
print("mean dif:", meandif)
print("max dif:", maxdif)


def compare_lm(path="junnyu/microsoft-DialoGPT-small"):
pdmodel = PDGPT2LMHeadModel.from_pretrained(path)
ptmodel = PTGPT2LMHeadModel.from_pretrained(path).cuda()
if "chinese" in path:
text = "欢迎使用paddlenlp!"
tokenizer = BertTokenizer.from_pretrained(path)
else:
text = "Welcome to paddlenlp!"
tokenizer = GPTTokenizer.from_pretrained(path)
pdmodel.eval()
ptmodel.eval()
pdinputs = {
k: paddle.to_tensor(
v, dtype="int64").unsqueeze(0)
for k, v in tokenizer(
text, return_token_type_ids=False).items()
}
ptinputs = {
k: torch.tensor(
v, dtype=torch.long).unsqueeze(0).cuda()
for k, v in tokenizer(
text, return_token_type_ids=False).items()
}

pd_logits = pdmodel(**pdinputs)

pt_logits = ptmodel(**ptinputs).logits

compare(pd_logits, pt_logits)


def test_GPTForTokenClassification():

tokenizer = GPTTokenizer.from_pretrained("junnyu/distilgpt2")
m = GPTForTokenClassification.from_pretrained("junnyu/distilgpt2")
inputs = tokenizer(
"Welcome to use PaddlePaddle and PaddleNLP!",
return_token_type_ids=False)
inputs = {
k: paddle.to_tensor(
[v], dtype="int64")
for (k, v) in inputs.items()
}
logits = m(**inputs)
print(logits.shape)


def test_GPTForSequenceClassification():
paddle.set_grad_enabled(False)
tokenizer = GPTTokenizer.from_pretrained("junnyu/distilgpt2")
m = GPTForSequenceClassification.from_pretrained("junnyu/distilgpt2")
inputs = tokenizer(
"Welcome to use PaddlePaddle and PaddleNLP!",
return_token_type_ids=False)
inputs = {
k: paddle.to_tensor(
[v], dtype="int64")
for (k, v) in inputs.items()
}
logits = m(**inputs)
print(logits.shape)


if __name__ == "__main__":
# compare_lm(
# path="junnyu/microsoft-DialoGPT-small")
# mean dif: 7.501994e-05
# max dif: 0.00036621094
# compare_lm(
# path="junnyu/distilgpt2")
# mean dif: 7.249901e-06
# max dif: 5.340576e-05
# compare_lm(
# path="junnyu/uer-gpt2-chinese-poem")
# mean dif: 1.0497178e-06
# max dif: 1.335144e-05

# test_GPTForTokenClassification()
# [1, 13, 2]
test_GPTForSequenceClassification()
# [1, 2]
99 changes: 99 additions & 0 deletions community/junnyu/gpt_convert_huggingface2paddle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
import argparse

huggingface_to_paddle = {
"transformer.wte.weight": "gpt.embeddings.word_embeddings.weight",
"transformer.wpe.weight": "gpt.embeddings.position_embeddings.weight",
"transformer.h.": "gpt.decoder.layers.",
".attn.c_proj.": ".self_attn.out_proj.",
".ln_1.": ".norm1.",
".mlp.c_fc.": ".linear1.",
".mlp.c_proj.": ".linear2.",
".ln_2.": ".norm2.",
"transformer.ln_f.": "gpt.decoder.norm.",
"lm_head.weight": "lm_head.decoder_weight"
}

skip_weights = [".attn.bias", "lm_head.weight"]
dont_transpose = [
".wte.weight", ".wpe.weight", ".ln_", ".mlp.c_proj.", ".mlp.c_fc.",
".attn.c_proj.", "lm_head.weight"
]


# 注意,huggingface使用的Conv1D的weight和paddle.nn.Linear中的weight形状一致,因此不需要转置。
# 如果使用了torch.nn.Linear那么就需要转置了!
def convert_pytorch_checkpoint_to_paddle(pytorch_checkpoint_path,
paddle_dump_path):
import torch
import paddle
pytorch_state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
paddle_state_dict = OrderedDict()
for k, v in pytorch_state_dict.items():
is_transpose = False
if k in skip_weights:
continue
# c_attn
if ".attn.c_attn." in k:
query_value_key = v.chunk(chunks=3, dim=-1)
for cross_value, new_name in zip(query_value_key, [
".self_attn.q_proj.", ".self_attn.k_proj.",
".self_attn.v_proj."
]):
oldk = k
newk = k.replace("transformer.h.",
"gpt.decoder.layers.").replace(".attn.c_attn.",
new_name)
paddle_state_dict[newk] = cross_value.data.numpy().astype(
"float32")
print(
f"Converting: {oldk} => {newk} | is_transpose {is_transpose}"
)
continue

if k[-7:] == ".weight":
if not any([w in k for w in dont_transpose]):
if v.ndim == 2:
v = v.transpose(0, 1)
is_transpose = True
oldk = k
for huggingface_name, paddle_name in huggingface_to_paddle.items():
k = k.replace(huggingface_name, paddle_name)

print(f"Converting: {oldk} => {k} | is_transpose {is_transpose}")
paddle_state_dict[k] = v.data.numpy().astype("float32")

paddle.save(paddle_state_dict, paddle_dump_path)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--pytorch_checkpoint_path",
default=r"community\junnyu\microsoft-DialoGPT-small\pytorch_model.bin",
type=str,
required=False,
help="Path to the Pytorch checkpoint path.")
parser.add_argument(
"--paddle_dump_path",
default=r"community\junnyu\microsoft-DialoGPT-small\model_state.pdparams",
type=str,
required=False,
help="Path to the output Paddle model.")
args = parser.parse_args()
convert_pytorch_checkpoint_to_paddle(args.pytorch_checkpoint_path,
args.paddle_dump_path)
67 changes: 67 additions & 0 deletions community/junnyu/microsoft-DialoGPT-large/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# 详细介绍
# microsoft-DialoGPT-small
**介绍**: 最先进的大规模预训练响应生成模型 (DialoGPT)
DialoGPT 是一种用于多轮对话的 SOTA 大规模预训练对话响应生成模型。 人类评估结果表明,DialoGPT 生成的响应与单轮对话图灵测试下的人类响应质量相当。 该模型是在来自 Reddit 讨论的 147M 多轮对话上训练的。

**模型结构**: **`GPTLMHeadModel`**,GPT模型。

**适用下游任务**:**文本生成**。

# 使用示例

```python
import numpy as np

import paddle
from paddlenlp.transformers import GPTTokenizer, GPTLMHeadModel
from paddlenlp.utils.log import logger


class Demo:
def __init__(self, model_name_or_path="junnyu/microsoft-DialoGPT-small", max_predict_len=32):

self.tokenizer = GPTTokenizer.from_pretrained(model_name_or_path)
logger.info("Loading the model parameters, please wait...")
self.max_predict_len = max_predict_len
self.model = GPTLMHeadModel.from_pretrained(
model_name_or_path, eol_token_id=self.tokenizer.eol_token_id
)
self.model.eval()

logger.info("Model loaded.")

@paddle.no_grad()
def predict(self):
# Let's chat for 5 lines
for step in range(5):
# encode the new user input, add the eos_token and return a tensor in Pytorch
ids = self.tokenizer(input(">> User:"))["input_ids"] + [self.tokenizer.eos_token_id]
new_user_input_ids = paddle.to_tensor(np.array(ids).reshape(1, -1).astype("int64"))

# append the new user input tokens to the chat history
bot_input_ids = paddle.concat([chat_history_ids, new_user_input_ids], axis=-1) if step > 0 else new_user_input_ids


# generated a response while limiting the total chat history to 1000 tokens,
chat_history_ids = self.model.generate(bot_input_ids, max_length=self.max_predict_len, pad_token_id=self.tokenizer.eos_token_id,decode_strategy="sampling",top_k=5,)[0]

# pretty print last ouput tokens from bot
print("DialoGPT: {}".format(self.tokenizer.convert_ids_to_string(chat_history_ids[0].tolist()).replace("<|endoftext|>","")))
chat_history_ids = paddle.concat([new_user_input_ids, chat_history_ids], axis=-1)

demo = Demo(model_name_or_path="junnyu/microsoft-DialoGPT-large")
demo.predict()

# >> User: Does money buy happiness?
# DialoGPT: No , but it can buy you a better life .
# >> User: What is the best way to buy happiness ?
# DialoGPT: A job , money , and a better life .
# >> User: This is so difficult !
# DialoGPT: Just get a job , money , and a better life . Then you can buy happiness .
# >> User: Oh, thank you!
# DialoGPT: No problem , friend .
```

# 权重来源

https://huggingface.co/microsoft/DialoGPT-large
7 changes: 7 additions & 0 deletions community/junnyu/microsoft-DialoGPT-large/files.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"model_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/microsoft-DialoGPT-large/model_config.json",
"model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/microsoft-DialoGPT-large/model_state.pdparams",
"tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/microsoft-DialoGPT-large/tokenizer_config.json",
"merges_file":"https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/microsoft-DialoGPT-large/merges.txt",
"vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/microsoft-DialoGPT-large/vocab.json"
}
Loading