3
3
from shutil import copyfile
4
4
5
5
import torch
6
- from fastNLP import Trainer , Vocabulary
6
+ from fastNLP import Trainer
7
7
from fastNLP .core .optimizer import AdamW
8
- from fastNLP .io .file_utils import cached_path
9
8
10
9
from .model .bert import BertEmbedding
11
10
from .model .finetune_dataloader import (fastHan_CWS_Loader , fastHan_NER_Loader ,
12
11
fastHan_Parsing_Loader ,
13
12
fastHan_POS_loader )
14
13
from .model .model import CharModel
15
14
from .model .UserDict import UserDict
16
-
15
+ from . model . utils import hf_cached_path
17
16
18
17
class Token (object ):
19
18
"""
@@ -85,7 +84,11 @@ class FastHan(object):
85
84
FastHan类封装了基于BERT的深度学习联合模型CharModel,可处理CWS、POS、NER、dependency parsing四项任务,这\
86
85
四项任务共享参数。
87
86
"""
88
-
87
+ HF_URL_MAP = {
88
+ "base" : 'fdugzc/fasthan_base' ,
89
+ "large" : 'fdugzc/fasthan_large'
90
+ }
91
+ CACHE_SUB_DIR = "fasthan"
89
92
90
93
def __init__ (self ,model_type = 'base' ,url = None ):
91
94
"""
@@ -96,6 +99,9 @@ def __init__(self,model_type='base',url=None):
96
99
97
100
:param str url:默认为None,用户可通过此参数传入手动下载并解压后的目录路径。
98
101
"""
102
+ if model_type not in ["base" ,"large" ]:
103
+ raise ValueError ("model_type can only be base or large." )
104
+
99
105
self .device = 'cpu'
100
106
#获取模型的目录/下载模型
101
107
if url is not None :
@@ -287,19 +293,8 @@ def set_cws_style(self,corpus):
287
293
corpus = 'CWS-' + corpus
288
294
self .tag_map ['CWS' ]= self .corpus_map [corpus ]
289
295
290
- def _get_model (self ,model_type ):
291
-
292
- #首先检查本地目录中是否已缓存模型,若没有缓存则下载。
293
-
294
- if model_type == 'base' :
295
- url = 'http://212.129.155.247/fasthan/fasthan_base.zip'
296
- elif model_type == 'large' :
297
- url = 'http://212.129.155.247/fasthan/fasthan_large.zip'
298
- else :
299
- raise ValueError ("model_type can only be base or large." )
300
-
301
- model_dir = cached_path (url ,name = 'fasthan' )
302
- return model_dir
296
def _get_model(self, model_type):
    """
    Resolve the location of the pretrained model for ``model_type``.

    The repository id registered for ``model_type`` in
    ``FastHan.HF_URL_MAP`` is passed, together with
    ``FastHan.CACHE_SUB_DIR``, to ``hf_cached_path``.

    :param str model_type: a key of ``FastHan.HF_URL_MAP`` ('base' or 'large').
    :return: the value returned by ``hf_cached_path``; presumably the local
        directory holding the cached model files -- TODO confirm against
        ``model/utils``.
    :raises ValueError: if ``model_type`` has no entry in ``FastHan.HF_URL_MAP``.
    """
    try:
        repo_id = FastHan.HF_URL_MAP[model_type]
    except KeyError:
        # Keep the exception type and message used by the constructor's own
        # validation, so a direct call fails the same way instead of leaking
        # a bare KeyError from the dict lookup.
        raise ValueError("model_type can only be base or large.") from None
    return hf_cached_path(repo_id, FastHan.CACHE_SUB_DIR)
303
298
304
299
def _to_tensor (self ,chars ,target ,seq_len ):
305
300
0 commit comments