Skip to content

Commit d1190e5

Browse files
authored
Merge pull request #50 from fastnlp/fix_1.x
从huggingface modelhub加载模型源文件
2 parents 09550a7 + 6dd6280 commit d1190e5

File tree

7 files changed

+53
-44
lines changed

7 files changed

+53
-44
lines changed

.vscode/settings.json

Lines changed: 0 additions & 3 deletions
This file was deleted.

CHANGELOG.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
**版本更新:**
2+
- 1.1版本的fastHan与0.5.5版本的fastNLP会导致importerror。如果使用1.1版本的fastHan,请使用0.5.0版本的fastNLP。
3+
- 1.2版本的fastHan修复了fastNLP版本兼容问题。小于等于1.2版本的fastHan在输入句子的首尾包含**空格、换行**符时会产生BUG。如果字符串首尾包含上述字符,请使用 strip 函数处理输入字符串。
4+
- 1.3版本的fastHan自动对输入字符串做 strip 函数处理。
5+
- 1.4版本的fastHan加入用户词典功能(仅限于分词任务)
6+
- 1.5版本的fastHan
7+
- 修正了Parsing任务中可能会出现的ValueError
8+
- 修改结果的返回形式,默认以list的形式返回
9+
- 可以通过url路径加载模型
10+
- 1.6版本的fastHan
11+
- 将用户词典功能扩充到所有任务
12+
- 可以在返回值中包含位置信息
13+
- 1.7版本的fastHan
14+
- 添加finetune功能
15+
- 1.8
16+
- 改为从huggingface modelhub加载模型文件

README.md

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,30 +25,10 @@ Zhichao Geng, Hang Yan, Xipeng Qiu and Xuanjing Huang, fastHan: A BERT-based Mul
2525

2626

2727
## 安装指南
28-
fastHan需要以下依赖的包:
29-
30-
- torch>=1.0.0
31-
- fastNLP>=0.5.5
32-
33-
**版本更新:**
34-
- 1.1版本的fastHan与0.5.5版本的fastNLP会导致importerror。如果使用1.1版本的fastHan,请使用0.5.0版本的fastNLP。
35-
- 1.2版本的fastHan修复了fastNLP版本兼容问题。小于等于1.2版本的fastHan在输入句子的首尾包含**空格、换行**符时会产生BUG。如果字符串首尾包含上述字符,请使用 strip 函数处理输入字符串。
36-
- 1.3版本的fastHan自动对输入字符串做 strip 函数处理。
37-
- 1.4版本的fastHan加入用户词典功能(仅限于分词任务)
38-
- 1.5版本的fastHan
39-
- 修正了Parsing任务中可能会出现的ValueError
40-
- 修改结果的返回形式,默认以list的形式返回
41-
- 可以通过url路径加载模型
42-
- 1.6版本的fastHan
43-
- 将用户词典功能扩充到所有任务
44-
- 可以在返回值中包含位置信息
45-
- 1.7版本的fastHan
46-
- 添加finetune功能
47-
4828
可执行如下命令完成安装:
4929

5030
```
51-
pip install fastHan
31+
pip install fastHan==1.8
5232
```
5333

5434
或者可以通过github安装:

fastHan/FastModel.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,16 @@
33
from shutil import copyfile
44

55
import torch
6-
from fastNLP import Trainer, Vocabulary
6+
from fastNLP import Trainer
77
from fastNLP.core.optimizer import AdamW
8-
from fastNLP.io.file_utils import cached_path
98

109
from .model.bert import BertEmbedding
1110
from .model.finetune_dataloader import (fastHan_CWS_Loader, fastHan_NER_Loader,
1211
fastHan_Parsing_Loader,
1312
fastHan_POS_loader)
1413
from .model.model import CharModel
1514
from .model.UserDict import UserDict
16-
15+
from .model.utils import hf_cached_path
1716

1817
class Token(object):
1918
"""
@@ -85,7 +84,11 @@ class FastHan(object):
8584
FastHan类封装了基于BERT的深度学习联合模型CharModel,可处理CWS、POS、NER、dependency parsing四项任务,这\
8685
四项任务共享参数。
8786
"""
88-
87+
HF_URL_MAP = {
88+
"base": 'fdugzc/fasthan_base',
89+
"large": 'fdugzc/fasthan_large'
90+
}
91+
CACHE_SUB_DIR = "fasthan"
8992

9093
def __init__(self,model_type='base',url=None):
9194
"""
@@ -96,6 +99,9 @@ def __init__(self,model_type='base',url=None):
9699
97100
:param str url:默认为None,用户可通过此参数传入手动下载并解压后的目录路径。
98101
"""
102+
if model_type not in ["base","large"]:
103+
raise ValueError("model_type can only be base or large.")
104+
99105
self.device='cpu'
100106
#获取模型的目录/下载模型
101107
if url is not None:
@@ -287,19 +293,8 @@ def set_cws_style(self,corpus):
287293
corpus='CWS-'+corpus
288294
self.tag_map['CWS']=self.corpus_map[corpus]
289295

290-
def _get_model(self,model_type):
291-
292-
#首先检查本地目录中是否已缓存模型,若没有缓存则下载。
293-
294-
if model_type=='base':
295-
url='http://212.129.155.247/fasthan/fasthan_base.zip'
296-
elif model_type=='large':
297-
url='http://212.129.155.247/fasthan/fasthan_large.zip'
298-
else:
299-
raise ValueError("model_type can only be base or large.")
300-
301-
model_dir=cached_path(url,name='fasthan')
302-
return model_dir
296+
def _get_model(self, model_type):
297+
return hf_cached_path(FastHan.HF_URL_MAP[model_type], FastHan.CACHE_SUB_DIR)
303298

304299
def _to_tensor(self,chars,target,seq_len):
305300

fastHan/model/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import os
22

3+
from pathlib import Path
34
from typing import Union, Dict
5+
from fastNLP.io.file_utils import get_cache_path, unzip_file
6+
from transformers.utils import cached_file
47

58

69
def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
@@ -69,3 +72,20 @@ def get_tokenizer():
6972
except Exception as e:
7073
print('use raw tokenizer')
7174
return lambda x: x.split()
75+
76+
# 返回本地缓存的模型目录路径
77+
# 若本地无缓存,从huggingface中下载并解压
78+
# 修改自 fastNLP.io.file_utils.cached_path, transformers.utils.cached_file
79+
def hf_cached_path(model_url: str, cache_sub_dir: str):
80+
cache_dir = os.path.join(Path(get_cache_path()), cache_sub_dir)
81+
os.makedirs(cache_dir, exist_ok=True)
82+
83+
# model_name 为 fasthan_base 或 fasthan_large
84+
model_name = model_url.split("/")[-1]
85+
target_path = os.path.join(cache_dir, model_name)
86+
87+
if model_name not in os.listdir(cache_dir):
88+
# 若本地不存在缓存, 从huggingface中下载
89+
zipped_file = cached_file(model_url, model_name+".zip")
90+
unzip_file(zipped_file, cache_dir)
91+
return target_path

requirements.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
torch>=1.0.0
2-
FastNLP>=0.5.5
1+
torch>=1.0.0, <2.0.0
2+
FastNLP>=0.5.5, <1.0.0
3+
transformers >=4.0.0, <=4.35.0

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
setup(
1616
name='fastHan',
17-
version='1.7',
17+
version='1.8',
1818
url='https://github.com/fastnlp/fastHan',
1919
description=(
2020
'使用深度学习联合模型,解决中文分词、词性标注、依存分析、命名实体识别任务。'

0 commit comments

Comments
 (0)