@@ -1,5 +1,4 @@
 import sys
-from typing import Any
 from io import BytesIO
 from pathlib import Path
 import srsly
@@ -15,6 +14,17 @@
 from transformers import AutoModel, AutoConfig, AutoTokenizer
 
 
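+# Swap HFShim.to_bytes for the custom implementation below; the original
+# method is returned so a later recover_hf_shims_to_bytes() can undo it.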
+def override_hf_shims_to_bytes():
+    assert hf_shim.HFShim.to_bytes is not HFShimCustom.to_bytes
+    origin = hf_shim.HFShim.to_bytes
+    hf_shim.HFShim.to_bytes = HFShimCustom.to_bytes
+    return origin
+
+def recover_hf_shims_to_bytes(origin):
+    assert hf_shim.HFShim.to_bytes is HFShimCustom.to_bytes
+    hf_shim.HFShim.to_bytes = origin
+
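+# Minimal usage sketch (hypothetical caller code; `nlp` is assumed to be a
+# loaded spaCy pipeline, not something defined in this module):
+#
+#     origin = override_hf_shims_to_bytes()
+#     try:
+#         data = nlp.to_bytes()  # serialization now routed through HFShimCustom
+#     finally:
+#         recover_hf_shims_to_bytes(origin)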
+
 def override_hf_shims_from_bytes():
     assert hf_shim.HFShim.from_bytes is not HFShimCustom.from_bytes
     origin = hf_shim.HFShim.from_bytes
@@ -28,6 +38,44 @@ def recover_hf_shims_from_bytes(origin):
 
 class HFShimCustom(HFShim):
 
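+    # Replacement serializer: the state-dict dump below is left disabled
+    # (weights_bytes is commented out), so the msgpack payload carries the
+    # config and tokenizer files but no "state" entry with model weights.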
+    def to_bytes(self):
+        config = {}
+        tok_dict = {}
+        # weights_bytes = {}
+        tok_cfg = {}
+        trf_cfg = {}
+        hf_model = self._hfmodel
+        if hf_model.transformer is not None:
+            tok_dict = {}
+            config = hf_model.transformer.config.to_dict()
+            tokenizer = hf_model.tokenizer
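+            # Write the tokenizer out via save_pretrained() into a temp dir,
+            # then capture every file it produced as raw bytes in tok_dict.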
+            with make_tempdir() as temp_dir:
+                if hasattr(tokenizer, "vocab_file"):
+                    vocab_file_name = tokenizer.vocab_files_names["vocab_file"]
+                    vocab_file_path = str((temp_dir / vocab_file_name).absolute())
+                    with open(vocab_file_path, "wb") as fileh:
+                        fileh.write(hf_model.vocab_file_contents)
+                    tokenizer.vocab_file = vocab_file_path
+                tokenizer.save_pretrained(str(temp_dir.absolute()))
+                for x in temp_dir.glob("**/*"):
+                    if x.is_file():
+                        tok_dict[x.name] = x.read_bytes()
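+            # The state dict is still serialized to an in-memory buffer, but
+            # with weights_bytes commented out the result is never stored.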
+            filelike = BytesIO()
+            torch.save(self._model.state_dict(), filelike)
+            filelike.seek(0)
+            # weights_bytes = filelike.getvalue()
+        else:
+            tok_cfg = hf_model._init_tokenizer_config
+            trf_cfg = hf_model._init_transformer_config
+        msg = {
+            "config": config,
+            # "state": weights_bytes,
+            "tokenizer": tok_dict,
+            "_init_tokenizer_config": tok_cfg,
+            "_init_transformer_config": trf_cfg,
+        }
+        return srsly.msgpack_dumps(msg)
+
     def from_bytes(self, bytes_data):
         msg = srsly.msgpack_loads(bytes_data)
         config_dict = msg["config"]
@@ -62,34 +110,35 @@ def from_bytes(self, bytes_data):
                     with open(vocab_file_path, "rb") as fileh:
                         vocab_file_contents = fileh.read()
 
-            try:
+            ops = get_current_ops()
+            if ops.device_type == "cpu":
+                map_location = "cpu"
+            else:  # pragma: no cover
+                device_id = torch.cuda.current_device()
+                map_location = f"cuda:{device_id}"
+
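+            # Payloads from the stock shim include the weights under "state";
+            # payloads from the custom to_bytes() above omit it, so fall back
+            # to re-loading the model by name (locally first, then the Hub).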
+            if "state" in msg:
                 transformer = AutoModel.from_config(config)
-            except OSError as e:
+                filelike = BytesIO(msg["state"])
+                filelike.seek(0)
+                transformer.load_state_dict(torch.load(filelike, map_location=map_location))
+            else:
                 try:
-                    transformer = AutoModel.from_pretrained(config["_name_or_path"], local_files_only=True)
+                    transformer = AutoModel.from_pretrained(config._name_or_path, local_files_only=True)
                 except OSError as e2:
-                    print("trying to download model from huggingface hub:", config["_name_or_path"], "...", file=sys.stderr)
-                    transformer = AutoModel.from_pretrained(config["_name_or_path"])
+                    print("trying to download model from huggingface hub:", config._name_or_path, "...", file=sys.stderr)
+                    transformer = AutoModel.from_pretrained(config._name_or_path)
                     print("succeeded", file=sys.stderr)
 
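+            # Move the reconstructed model onto the active device right away.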
+            transformer.to(map_location)
+            self._model = transformer
             self._hfmodel = HFObjects(
                 tokenizer,
                 transformer,
                 vocab_file_contents,
                 SimpleFrozenDict(),
                 SimpleFrozenDict(),
             )
-            self._model = transformer
-            filelike = BytesIO(msg["state"])
-            filelike.seek(0)
-            ops = get_current_ops()
-            if ops.device_type == "cpu":
-                map_location = "cpu"
-            else:  # pragma: no cover
-                device_id = torch.cuda.current_device()
-                map_location = f"cuda:{device_id}"
-            self._model.load_state_dict(torch.load(filelike, map_location=map_location))
-            self._model.to(map_location)
         else:
             self._hfmodel = HFObjects(
                 None,