
Commit c862e9c

JunnYu and yingyibiao authored
PaddlePaddle Hackathon 52 submission (#1085)

* update
* update
* add community/junnyu
* update bert docs
* rm tokenizer links
* fix typo
* update TestBertForMaskedLM
* update
* remove the redundant attention mask check
* update readme
* update docs
* add
* Revert "update docs". This reverts commit 79169c2.
* update docs/model_zoo
* update modelzoo rst
* replace tab with space
* fix typo

Co-authored-by: yingyibiao <yyb0576@163.com>
1 parent 7d0d89d commit c862e9c

File tree

15 files changed: +2064 −454 lines


community/junnyu/bert_compare.py

Lines changed: 663 additions & 0 deletions
Large diffs are not rendered by default.

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
import argparse

# Mapping from Hugging Face BERT parameter names to their PaddleNLP equivalents.
huggingface_to_paddle = {
    "embeddings.LayerNorm": "embeddings.layer_norm",
    "encoder.layer": "encoder.layers",
    "attention.self.query": "self_attn.q_proj",
    "attention.self.key": "self_attn.k_proj",
    "attention.self.value": "self_attn.v_proj",
    "attention.output.dense": "self_attn.out_proj",
    "intermediate.dense": "linear1",
    "output.dense": "linear2",
    "attention.output.LayerNorm": "norm1",
    "output.LayerNorm": "norm2",
    "predictions.decoder.": "predictions.decoder_",
    "predictions.transform.dense": "predictions.transform",
    "predictions.transform.LayerNorm": "predictions.layer_norm",
}


def convert_pytorch_checkpoint_to_paddle(pytorch_checkpoint_path,
                                         paddle_dump_path):
    import torch
    import paddle
    pytorch_state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
    paddle_state_dict = OrderedDict()
    for k, v in pytorch_state_dict.items():
        is_transpose = False
        # Linear weights are stored as (out, in) in PyTorch but (in, out) in
        # Paddle, so transpose every 2-D ".weight" tensor that is not an
        # embedding or LayerNorm parameter.
        if k[-7:] == ".weight":
            if ".embeddings." not in k and ".LayerNorm." not in k:
                if v.ndim == 2:
                    v = v.transpose(0, 1)
                    is_transpose = True
        oldk = k
        for huggingface_name, paddle_name in huggingface_to_paddle.items():
            k = k.replace(huggingface_name, paddle_name)

        # Parameters outside the bert./cls./classifier namespaces get the
        # "bert." prefix expected by PaddleNLP's BertModel.
        if "bert." not in k and "cls." not in k and "classifier" not in k:
            k = "bert." + k

        print(f"Converting: {oldk} => {k} | is_transpose {is_transpose}")
        paddle_state_dict[k] = v.data.numpy()

    paddle.save(paddle_state_dict, paddle_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pytorch_checkpoint_path",
        default="MODEL/ckiplab-bert-base-chinese-ws/pytorch_model.bin",
        type=str,
        required=False,
        help="Path to the PyTorch checkpoint.")
    parser.add_argument(
        "--paddle_dump_path",
        default="MODEL/ckiplab-bert-base-chinese-ws/model_state.pdparams",
        type=str,
        required=False,
        help="Path to the output Paddle model.")
    args = parser.parse_args()
    convert_pytorch_checkpoint_to_paddle(args.pytorch_checkpoint_path,
                                         args.paddle_dump_path)
```
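As a sanity check on the transpose step above, the weight-layout difference between the two frameworks can be observed directly. The following is a minimal sketch, assuming both `torch` and `paddle` are installed; it is not part of the converter itself:

```python
# torch.nn.Linear stores its weight with shape (out_features, in_features),
# while paddle.nn.Linear stores (in_features, out_features).
import torch
import paddle

t_linear = torch.nn.Linear(4, 8)
p_linear = paddle.nn.Linear(4, 8)
print(t_linear.weight.shape)  # torch.Size([8, 4])
print(p_linear.weight.shape)  # [4, 8]

# This is why the converter transposes every 2-D ".weight" tensor that is
# not an embedding or LayerNorm weight before saving it for Paddle.
```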
Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
# Detailed description

**Description**: ckiplab-bert-base-chinese-ner is a BERT model with a token-classification head that has been fine-tuned on a **named entity recognition task**.

For full usage instructions and other details, see https://github.com/ckiplab/ckip-transformers

**Model architecture**: **`BertForTokenClassification`**, a BERT model with a token-classification head.
**Supported downstream task**: **named entity recognition**; the weights have already been fine-tuned on the downstream `NER` task and can be used directly.

# Usage example

```python
import paddle
import paddle.nn.functional as F
from paddlenlp.transformers import BertForTokenClassification, BertTokenizer

path = "junnyu/ckiplab-bert-base-chinese-ner"
model = BertForTokenClassification.from_pretrained(path)
model.eval()
tokenizer = BertTokenizer.from_pretrained(path)
text = "傅達仁今將執行安樂死,卻突然爆出自己20年前遭緯來體育台封殺,他不懂自己哪裡得罪到電視台。"
tokenized_text = tokenizer.tokenize(text)
inputs = {
    k: paddle.to_tensor(v, dtype="int64").unsqueeze(0)
    for k, v in tokenizer(text).items()
}
with paddle.no_grad():
    score = F.softmax(model(**inputs), axis=-1)
id2label = {
    "0": "O",
    "1": "B-CARDINAL",
    "2": "B-DATE",
    "3": "B-EVENT",
    "4": "B-FAC",
    "5": "B-GPE",
    "6": "B-LANGUAGE",
    "7": "B-LAW",
    "8": "B-LOC",
    "9": "B-MONEY",
    "10": "B-NORP",
    "11": "B-ORDINAL",
    "12": "B-ORG",
    "13": "B-PERCENT",
    "14": "B-PERSON",
    "15": "B-PRODUCT",
    "16": "B-QUANTITY",
    "17": "B-TIME",
    "18": "B-WORK_OF_ART",
    "19": "I-CARDINAL",
    "20": "I-DATE",
    "21": "I-EVENT",
    "22": "I-FAC",
    "23": "I-GPE",
    "24": "I-LANGUAGE",
    "25": "I-LAW",
    "26": "I-LOC",
    "27": "I-MONEY",
    "28": "I-NORP",
    "29": "I-ORDINAL",
    "30": "I-ORG",
    "31": "I-PERCENT",
    "32": "I-PERSON",
    "33": "I-PRODUCT",
    "34": "I-QUANTITY",
    "35": "I-TIME",
    "36": "I-WORK_OF_ART",
    "37": "E-CARDINAL",
    "38": "E-DATE",
    "39": "E-EVENT",
    "40": "E-FAC",
    "41": "E-GPE",
    "42": "E-LANGUAGE",
    "43": "E-LAW",
    "44": "E-LOC",
    "45": "E-MONEY",
    "46": "E-NORP",
    "47": "E-ORDINAL",
    "48": "E-ORG",
    "49": "E-PERCENT",
    "50": "E-PERSON",
    "51": "E-PRODUCT",
    "52": "E-QUANTITY",
    "53": "E-TIME",
    "54": "E-WORK_OF_ART",
    "55": "S-CARDINAL",
    "56": "S-DATE",
    "57": "S-EVENT",
    "58": "S-FAC",
    "59": "S-GPE",
    "60": "S-LANGUAGE",
    "61": "S-LAW",
    "62": "S-LOC",
    "63": "S-MONEY",
    "64": "S-NORP",
    "65": "S-ORDINAL",
    "66": "S-ORG",
    "67": "S-PERCENT",
    "68": "S-PERSON",
    "69": "S-PRODUCT",
    "70": "S-QUANTITY",
    "71": "S-TIME",
    "72": "S-WORK_OF_ART"
}
for t, s in zip(tokenized_text, score[0][1:-1]):
    index = paddle.argmax(s).item()
    label = id2label[str(index)]
    print(f"{label} {t} score {s[index].item()}")

# B-PERSON 傅 score 0.9999995231628418
# I-PERSON 達 score 0.9999994039535522
# E-PERSON 仁 score 0.9999995231628418
# B-DATE 今 score 0.9991734623908997
# O 將 score 0.9852147698402405
# O 執 score 1.0
# O 行 score 0.9999998807907104
# O 安 score 0.9999996423721313
# O 樂 score 0.9999997615814209
# O 死 score 0.9999997615814209
# O , score 1.0
# O 卻 score 1.0
# O 突 score 1.0
# O 然 score 1.0
# O 爆 score 1.0
# O 出 score 1.0
# O 自 score 1.0
# O 己 score 1.0
# B-DATE 20 score 0.9999992847442627
# E-DATE 年 score 0.9999892711639404
# O 前 score 0.9999995231628418
# O 遭 score 1.0
# B-ORG 緯 score 0.9999990463256836
# I-ORG 來 score 0.9999986886978149
# I-ORG 體 score 0.999998927116394
# I-ORG 育 score 0.9999985694885254
# E-ORG 台 score 0.999998927116394
# O 封 score 1.0
# O 殺 score 1.0
# O , score 1.0
# O 他 score 1.0
# O 不 score 1.0
# O 懂 score 1.0
# O 自 score 1.0
# O 己 score 1.0
# O 哪 score 1.0
# O 裡 score 1.0
# O 得 score 1.0
# O 罪 score 1.0
# O 到 score 1.0
# O 電 score 1.0
# O 視 score 1.0
# O 台 score 1.0
# O 。 score 0.9999960660934448
```
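The per-token `B-`/`I-`/`E-`/`S-` labels can be folded into entity spans with a few lines of plain Python. The `decode_entities` helper below is an illustrative sketch, not part of PaddleNLP; it reuses `tokenized_text`, `score`, and `id2label` from the example above:

```python
def decode_entities(tokens, labels):
    """Fold B-/I-/E-/S- token labels into (entity_text, entity_type) spans."""
    entities, buf, ent_type = [], [], None

    def flush():
        # Emit any open span, then reset the buffer.
        nonlocal buf, ent_type
        if buf:
            entities.append(("".join(buf), ent_type))
        buf, ent_type = [], None

    for tok, lab in zip(tokens, labels):
        if lab.startswith("S-"):            # single-token entity
            flush()
            entities.append((tok, lab[2:]))
        elif lab.startswith("B-"):          # a new entity opens
            flush()
            buf, ent_type = [tok], lab[2:]
        elif lab.startswith(("I-", "E-")) and buf:
            buf.append(tok)                 # entity continues
            if lab.startswith("E-"):        # entity closes
                flush()
        else:                               # "O" ends any open span
            flush()
    return entities

labels = [id2label[str(paddle.argmax(s).item())] for s in score[0][1:-1]]
print(decode_entities(tokenized_text, labels))
# [('傅達仁', 'PERSON'), ('今', 'DATE'), ('20年', 'DATE'), ('緯來體育台', 'ORG')]
```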
# Weight source

https://huggingface.co/ckiplab/bert-base-chinese-ner

This project provides Traditional Chinese transformer models (including ALBERT, BERT, and GPT2) and natural language processing tools (including word segmentation, part-of-speech tagging, and named entity recognition).
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
```json
{
    "model_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/ckiplab-bert-base-chinese-ner/model_config.json",
    "model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/ckiplab-bert-base-chinese-ner/model_state.pdparams",
    "tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/ckiplab-bert-base-chinese-ner/tokenizer_config.json",
    "vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/ckiplab-bert-base-chinese-ner/vocab.txt"
}
```
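This JSON appears to be the community-model file index that maps the `junnyu/ckiplab-bert-base-chinese-ner` name to its hosted config, weights, tokenizer config, and vocab, so the model can be loaded by name alone. A minimal usage sketch (same API as in the README above; the resolution behavior is my reading of PaddleNLP's community-models layout):

```python
from paddlenlp.transformers import BertForTokenClassification, BertTokenizer

# With the URL mapping above in place, PaddleNLP resolves all model files
# for this community model from the model name alone.
name = "junnyu/ckiplab-bert-base-chinese-ner"
model = BertForTokenClassification.from_pretrained(name)
tokenizer = BertTokenizer.from_pretrained(name)
```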
Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
# Detailed description

**Description**: ckiplab-bert-base-chinese-pos is a BERT model with a token-classification head that has been fine-tuned on a **part-of-speech tagging task**.

For full usage instructions and other details, see https://github.com/ckiplab/ckip-transformers

**Model architecture**: **`BertForTokenClassification`**, a BERT model with a token-classification head.
**Supported downstream task**: **part-of-speech tagging**; the weights have already been fine-tuned on the downstream `POS` task and can be used directly.

# Usage example

```python
import paddle
import paddle.nn.functional as F
from paddlenlp.transformers import BertForTokenClassification, BertTokenizer

path = "junnyu/ckiplab-bert-base-chinese-pos"
model = BertForTokenClassification.from_pretrained(path)
model.eval()
tokenizer = BertTokenizer.from_pretrained(path)
text = "傅達仁今將執行安樂死,卻突然爆出自己20年前遭緯來體育台封殺,他不懂自己哪裡得罪到電視台。"
tokenized_text = tokenizer.tokenize(text)
inputs = {
    k: paddle.to_tensor(v, dtype="int64").unsqueeze(0)
    for k, v in tokenizer(text).items()
}
with paddle.no_grad():
    score = F.softmax(model(**inputs), axis=-1)
id2label = {
    "0": "A",
    "1": "Caa",
    "2": "Cab",
    "3": "Cba",
    "4": "Cbb",
    "5": "D",
    "6": "Da",
    "7": "Dfa",
    "8": "Dfb",
    "9": "Di",
    "10": "Dk",
    "11": "DM",
    "12": "I",
    "13": "Na",
    "14": "Nb",
    "15": "Nc",
    "16": "Ncd",
    "17": "Nd",
    "18": "Nep",
    "19": "Neqa",
    "20": "Neqb",
    "21": "Nes",
    "22": "Neu",
    "23": "Nf",
    "24": "Ng",
    "25": "Nh",
    "26": "Nv",
    "27": "P",
    "28": "T",
    "29": "VA",
    "30": "VAC",
    "31": "VB",
    "32": "VC",
    "33": "VCL",
    "34": "VD",
    "35": "VF",
    "36": "VE",
    "37": "VG",
    "38": "VH",
    "39": "VHC",
    "40": "VI",
    "41": "VJ",
    "42": "VK",
    "43": "VL",
    "44": "V_2",
    "45": "DE",
    "46": "SHI",
    "47": "FW",
    "48": "COLONCATEGORY",
    "49": "COMMACATEGORY",
    "50": "DASHCATEGORY",
    "51": "DOTCATEGORY",
    "52": "ETCCATEGORY",
    "53": "EXCLAMATIONCATEGORY",
    "54": "PARENTHESISCATEGORY",
    "55": "PAUSECATEGORY",
    "56": "PERIODCATEGORY",
    "57": "QUESTIONCATEGORY",
    "58": "SEMICOLONCATEGORY",
    "59": "SPCHANGECATEGORY"
}
for t, s in zip(tokenized_text, score[0][1:-1]):
    index = paddle.argmax(s).item()
    label = id2label[str(index)]
    print(f"{label} {t} score {s[index].item()}")

# Nb 傅 score 0.9999998807907104
# Nb 達 score 0.9700667858123779
# Na 仁 score 0.9985846281051636
# Nd 今 score 0.9999947547912598
# D 將 score 0.9999957084655762
# VC 執 score 0.9999998807907104
# VC 行 score 0.9951109290122986
# Na 安 score 0.9999996423721313
# Na 樂 score 0.9999638795852661
# VH 死 score 0.9813857674598694
# COMMACATEGORY , score 1.0
# D 卻 score 1.0
# D 突 score 1.0
# Cbb 然 score 0.9989008903503418
# VJ 爆 score 0.9999979734420776
# VC 出 score 0.9965670108795166
# Nh 自 score 1.0
# Nh 己 score 1.0
# Neu 20 score 0.9999995231628418
# Nf 年 score 0.9125530123710632
# Ng 前 score 0.9999992847442627
# P 遭 score 1.0
# Nb 緯 score 0.9999996423721313
# VA 來 score 0.9322434663772583
# Na 體 score 0.9846553802490234
# Nc 育 score 0.729569137096405
# Nc 台 score 0.9999841451644897
# VC 封 score 0.9999997615814209
# VC 殺 score 0.9999991655349731
# COMMACATEGORY , score 1.0
# Nh 他 score 0.9999996423721313
# D 不 score 1.0
# VK 懂 score 1.0
# Nh 自 score 1.0
# Nh 己 score 0.9999978542327881
# Ncd 哪 score 0.9856181740760803
# Ncd 裡 score 0.9999995231628418
# VC 得 score 0.9999988079071045
# Na 罪 score 0.9994786381721497
# VCL 到 score 0.8332439661026001
# Nc 電 score 1.0
# Nc 視 score 0.9999986886978149
# Nc 台 score 0.9973978996276855
# PERIODCATEGORY 。 score 1.0
```
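For repeated use, the steps above can be wrapped into a single function. The `tag` helper below is a convenience sketch, not part of PaddleNLP; it reuses the `model`, `tokenizer`, and `id2label` objects defined in the example:

```python
def tag(text):
    """Return (token, label, confidence) triples for a piece of text."""
    tokens = tokenizer.tokenize(text)
    inputs = {
        k: paddle.to_tensor(v, dtype="int64").unsqueeze(0)
        for k, v in tokenizer(text).items()
    }
    with paddle.no_grad():
        probs = F.softmax(model(**inputs), axis=-1)
    results = []
    for tok, s in zip(tokens, probs[0][1:-1]):  # drop [CLS]/[SEP] positions
        idx = paddle.argmax(s).item()
        results.append((tok, id2label[str(idx)], s[idx].item()))
    return results

print(tag("傅達仁今將執行安樂死。")[:3])
```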
# Weight source

https://huggingface.co/ckiplab/bert-base-chinese-pos

This project provides Traditional Chinese transformer models (including ALBERT, BERT, and GPT2) and natural language processing tools (including word segmentation, part-of-speech tagging, and named entity recognition).
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
```json
{
    "model_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/ckiplab-bert-base-chinese-pos/model_config.json",
    "model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/ckiplab-bert-base-chinese-pos/model_state.pdparams",
    "tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/ckiplab-bert-base-chinese-pos/tokenizer_config.json",
    "vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/ckiplab-bert-base-chinese-pos/vocab.txt"
}
```
