Skip to content

Commit 7a2c396

Browse files
committed
fix(nlp/bert): update bert README and add a script
1 parent 3001b1c commit 7a2c396

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

pytorch/nlp/bert/README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
1. Prepare model
55
```
6+
apt-get update
7+
apt-get install git-lfs
68
git lfs install
79
git clone https://huggingface.co/google-bert/bert-base-chinese
810
@@ -32,8 +34,9 @@ bash run_train.sh
3234
bash run_dist_train.sh
3335
```
3436

35-
5. Inference
37+
5. Model Consistency Check
3638
```shell
39+
# ⚠️ Make sure the model_path in test_bert.py is correctly set before running
3740
cp -r test_bert.py bert4torch/test/models/
3841
python bert4torch/test/models/test_bert.py
3942
```

pytorch/nlp/bert/test_bert.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import torch
2+
import torch_musa
3+
from bert4torch.models import build_transformer_model
4+
from bert4torch.tokenizers import Tokenizer
5+
from transformers import BertConfig, BertTokenizer, BertModel
6+
import os
7+
8+
9+
device = 'musa' if torch.cuda.is_available() else 'cpu'
10+
11+
def get_bert4torch_model(model_dir):
12+
config_path = model_dir + "/bert4torch_config.json"
13+
if not os.path.exists(config_path):
14+
config_path = model_dir + "/config.json"
15+
checkpoint_path = model_dir + '/pytorch_model.bin'
16+
17+
model = build_transformer_model(config_path, checkpoint_path) # 建立模型,加载权重
18+
return model.to(device)
19+
20+
21+
def get_hf_model(model_dir):
22+
tokenizer = BertTokenizer.from_pretrained(model_dir)
23+
model = BertModel.from_pretrained(model_dir)
24+
return model.to(device), tokenizer
25+
26+
27+
@pytest.mark.parametrize("model_dir", ["E:/data/pretrain_ckpt/bert/google@bert-base-chinese",
28+
"E:/data/pretrain_ckpt/bert/bert-base-multilingual-cased",
29+
"E:/data/pretrain_ckpt/bert/hfl@macbert-base",
30+
"E:/data/pretrain_ckpt/bert/hfl@chinese-bert-wwm-ext"])
31+
@torch.inference_mode()
32+
def test_bert(model_dir):
33+
model = get_bert4torch_model(model_dir)
34+
model_hf, tokenizer = get_hf_model(model_dir)
35+
36+
model.eval()
37+
model_hf.eval()
38+
39+
inputs = tokenizer('语言模型', padding=True, return_tensors='pt').to(device)
40+
sequence_output = model(**inputs)
41+
sequence_output_hf = model_hf(**inputs).last_hidden_state
42+
print(f"Output mean diff: {(sequence_output - sequence_output_hf).abs().mean().item()}")
43+
44+
assert (sequence_output - sequence_output_hf).abs().max().item() < 1e-4
45+
46+
47+
if __name__=='__main__':
48+
test_bert("/data/bert-base-chinese/")

0 commit comments

Comments
 (0)